Source code for domainlab.algos.msels.c_msel_oracle

"""
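Oracle model selection: track the model with the best out-of-domain test
performance, decoupled from (i.e. without influencing) how the final model
is selected.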
"""
from domainlab.algos.msels.a_model_sel import AMSel
from domainlab.utils.logger import Logger


class MSelOracleVisitor(AMSel):
    """
    Oracle model-selection decorator.

    Keeps a checkpoint of the model with the best out-of-domain *test*
    metric seen so far, but does not affect how the final model is
    selected: that decision is left to the wrapped (inner) model selection,
    if one is given.
    """

    def __init__(self, msel=None, val_threshold=None):
        """
        Wrap an optional inner model selection object (decorator pattern).

        :param msel: inner model selection to delegate to; may be None.
        :param val_threshold: forwarded unchanged to ``AMSel.__init__``.
        """
        super().__init__(val_threshold)
        # Best test-domain metric observed so far.
        # NOTE(review): combined with the strict ``>`` in base_update, a test
        # metric that never exceeds 0 never triggers an oracle save.
        self.best_oracle_acc = 0
        self.msel = msel

    @property
    def oracle_last_setpoint_sel_te_acc(self):
        """
        Last setpoint test accuracy as reported by the inner model selection.

        :return: the inner selection's ``oracle_last_setpoint_sel_te_acc``,
            or -1 when there is no inner selection or it lacks that attribute.
        """
        inner = self.msel
        if inner is None or not hasattr(inner, "oracle_last_setpoint_sel_te_acc"):
            return -1
        return inner.oracle_last_setpoint_sel_te_acc

    def base_update(self, clear_counter=False):
        """
        Decide whether the best model should be updated.

        The current model is always saved under the "epoch" tag first. If the
        observer has no validation metric, the decision is deferred to
        ``AMSel.base_update``. Otherwise, when the test metric selected by
        ``observer4msel.str_metric4msel`` beats ``best_oracle_acc``, that value
        is recorded and the model is saved. The save is tagged "oracle" when an
        inner selection exists and untagged otherwise.

        :param clear_counter: forwarded to whichever ``base_update`` the
            decision is delegated to.
        :return: the inner selection's ``base_update`` result when an inner
            selection exists; otherwise whether a new oracle model was saved.
        """
        self.trainer.model.save("epoch")
        observer = self.observer4msel
        if observer.metric_val is None:
            return super().base_update(clear_counter)

        te_metric = observer.metric_te[observer.str_metric4msel]
        improved = bool(te_metric > self.best_oracle_acc)
        if improved:
            self.best_oracle_acc = te_metric
            # Presumably tagged separately so the oracle checkpoint does not
            # clobber the inner selection's default one -- confirm.
            if self.msel is None:
                self.trainer.model.save()
            else:
                self.trainer.model.save("oracle")
            Logger.get_logger().info("new oracle model saved")

        if self.msel is not None:
            return self.msel.base_update(clear_counter)
        return improved

    def early_stop(self):
        """
        Whether training should stop early.

        The oracle never intervenes in how models get selected. The decision
        comes from the innermost model selection, and is False when none is
        wrapped.
        """
        return False if self.msel is None else self.msel.early_stop()

    def accept(self, trainer, observer4msel):
        """
        Attach the trainer and observer.

        They are passed to the inner selection (if any) first, then to this
        object via ``AMSel.accept``.
        """
        inner = self.msel
        if inner is not None:
            inner.accept(trainer, observer4msel)
        super().accept(trainer, observer4msel)