Source code for domainlab.algos.msels.c_msel_tr_loss

"""
AMSel.accept ---> Trainer
"""
import math

from domainlab.algos.msels.a_model_sel import AMSel
from domainlab.utils.logger import Logger


class MSelTrLoss(AMSel):
    """
    Model selection driven by the training loss.

    1. Uses the sum of loss across training domains as selection criterion
    2. Visitor pattern to trainer: the loss is read from ``self.trainer``
       (presumably attached through ``AMSel.accept`` -- confirm in the
       base class)
    """

    def __init__(self, max_es, val_threshold=None):
        """
        :param max_es: number of consecutive non-improving updates at which
            ``early_stop`` starts returning True
        :param val_threshold: forwarded unchanged to ``AMSel.__init__``
        """
        # NOTE: the parent constructor has to run first, otherwise it would
        # overwrite the values initialised below!
        super().__init__(val_threshold)
        self.reset()
        self._max_es = max_es

    def reset(self):
        """
        Forget the best loss seen so far and zero the early stop counter.
        """
        self.best_loss = float("inf")
        self.es_c = 0

    @property
    def max_es(self):
        """
        Early stop threshold handed to the constructor (read-only).
        """
        return self._max_es

    def base_update(self, clear_counter=False):
        """
        Tell whether the best model should be updated.

        The current epoch training loss is taken from the trainer and
        compared with the best loss recorded so far.

        :param clear_counter: when the loss did not improve, set the early
            stop counter back to zero after it has been incremented and
            logged
        :return: True if the loss is strictly lower than the best loss so
            far (the best loss is then replaced and the counter reset),
            False otherwise
        """
        loss = self.trainer.epo_loss_tr  # @FIXME
        assert loss is not None
        assert not math.isnan(loss)

        if loss < self.best_loss:
            # improvement: restore counter and remember the new best loss
            self.es_c = 0
            self.best_loss = loss
            return True

        # no improvement: one more step towards early stopping
        self.es_c += 1
        logger = Logger.get_logger()
        logger.info(f"early stop counter: {self.es_c}")
        logger.info(f"loss:{loss}, best loss: {self.best_loss}")
        if clear_counter:
            # only relevant on this path, the counter is already zero when
            # the loss improved
            logger.info("clearing counter")
            self.es_c = 0
        # do not update best model
        return False

    def early_stop(self):
        """
        Tell whether training should stop early, i.e. whether the counter
        of non-improving updates has reached ``max_es``.
        """
        return self.es_c >= self.max_es