Source code for domainlab.algos.msels.a_model_sel

"""
Abstract Model Selection
"""

import abc


class AMSel(metaclass=abc.ABCMeta):
    """
    Abstract Model Selection

    A model selector is attached to a trainer (``accept``) and is queried
    every epoch whether the current model should be selected (``update``)
    and whether training should stop (``if_stop``). A selector may decorate
    another selector stored in ``self.msel``; when ``self.msel`` is set, most
    queries (``reset``, ``max_es``, ``early_stop`` and the metric properties)
    are forwarded to it.
    """

    def __init__(self, val_threshold=None):
        """
        :param val_threshold: optional validation accuracy below which
            ``if_stop`` never signals to stop; ``None`` disables the check.
        """
        self.trainer = None  # set by accept()
        self._observer = None  # set by accept()
        self.msel = None  # decorated model selector; None means no decoratee
        self._max_es = None  # own max early stop; overrides the decoratee's
        self._model_selection_epoch = None  # set by update() on selection
        self._val_threshold = val_threshold

    def reset(self):
        """
        reset the decorated model selector, if any (no-op otherwise)
        """
        if self.msel is not None:
            self.msel.reset()

    @property
    def observer4msel(self):
        """
        the observer from trainer, as passed to ``accept``
        """
        return self._observer

    @property
    def max_es(self):
        """
        maximum early stop: this selector's own value if set, otherwise the
        decorated selector's value, otherwise ``None``
        """
        if self._max_es is not None:
            return self._max_es
        if self.msel is not None:
            return self.msel.max_es
        return None

    def accept(self, trainer, observer4msel):
        """
        Visitor pattern to trainer: store the trainer and its observer, and
        propagate both to the decorated selector, if any.
        """
        self.trainer = trainer
        self._observer = observer4msel
        if self.msel is not None:
            self.msel.accept(trainer, observer4msel)

    def update(self, epoch, clear_counter=False):
        """
        level above the observer + visitor pattern to get information about
        the epoch: calls ``base_update`` and, if it reports that the model
        should be updated, records ``epoch`` as the model selection epoch.

        :return: boolean returned by ``base_update``
        """
        selected = self.base_update(clear_counter)
        if selected:
            self._model_selection_epoch = epoch
        return selected

    @abc.abstractmethod
    def base_update(self, clear_counter=False):
        """
        observer + visitor pattern to trainer
        if the best model should be updated
        return boolean
        """

    def if_stop(self, acc_val=None):
        """
        check if trainer should stop and additionally test for the validation
        threshold.

        :param acc_val: current validation accuracy; when given and strictly
            below ``val_threshold``, this returns False (do not stop).
        :return: boolean
        """
        # NOTE: if_stop is concrete (not abstract). A child class that
        # overrides it bypasses this threshold check unless it re-implements
        # it. The threshold is checked for both leaf selectors (msel is None)
        # and decorating selectors, so ``val_threshold`` is never silently
        # ignored.
        if (
            acc_val is not None
            and self._val_threshold is not None
            and acc_val < self._val_threshold
        ):
            return False
        return self.early_stop()

    def early_stop(self):
        """
        check if trainer should stop, delegating to the decorated selector
        return boolean

        :raises NotImplementedError: if there is no decorated selector and
            the child class does not override this method
        """
        if self.msel is not None:
            return self.msel.early_stop()
        raise NotImplementedError

    @property
    def best_val_acc(self):
        """
        decoratee best validation accuracy, -1 without a decoratee
        """
        if self.msel is not None:
            return self.msel.best_val_acc
        return -1

    @property
    def best_te_metric(self):
        """
        decoratee best test metric, -1 without a decoratee
        """
        if self.msel is not None:
            return self.msel.best_te_metric
        return -1

    @property
    def sel_model_te_acc(self):
        """
        the selected model test accuracy, -1 without a decoratee
        """
        if self.msel is not None:
            return self.msel.sel_model_te_acc
        return -1

    @property
    def model_selection_epoch(self):
        """
        the epoch when the model was selected, -1 if none was selected yet
        """
        if self._model_selection_epoch is not None:
            return self._model_selection_epoch
        return -1

    @property
    def val_threshold(self):
        """
        the threshold below which we don't stop early
        """
        return self._val_threshold