Source code for domainlab.algos.msels.c_msel_val
"""
Model Selection should be decoupled from
"""
from domainlab.algos.msels.c_msel_tr_loss import MSelTrLoss
from domainlab.utils.logger import Logger
[docs]
class MSelValPerf(MSelTrLoss):
    """
    Model selection based on validation performance.

    1. Model selection using validation performance: the model is considered
       improved whenever the validation metric named by
       ``observer4msel.str_metric4msel`` exceeds the best value seen so far
       (the bigger the better).
    2. Visitor pattern to trainer
    """

    def __init__(self, max_es, val_threshold=None):
        """
        :param max_es: forwarded to ``MSelTrLoss.__init__`` (presumably the
            early-stopping patience -- confirm in MSelTrLoss)
        :param val_threshold: forwarded to ``MSelTrLoss.__init__``
        """
        super().__init__(max_es, val_threshold)  # construct self.observer4msel (observer)
        self.reset()

    def reset(self):
        """
        Reset the parent's state, then reset the best validation metric, the
        selected model's test metric and the best test metric to 0.0.
        """
        super().reset()
        self._best_val_acc = 0.0
        self._sel_model_te_acc = 0.0
        self._best_te_metric = 0.0

    @property
    def sel_model_te_acc(self):
        """
        test metric of the selected model, i.e. the test value recorded at the
        call where the validation metric last improved
        """
        return self._sel_model_te_acc

    @property
    def best_val_acc(self):
        """
        decoratee best val acc (best value of the selection metric on the
        validation set seen so far)
        """
        return self._best_val_acc

    @property
    def best_te_metric(self):
        """
        decoratee best test metric (maximum test value seen so far,
        regardless of which model was selected)
        """
        return self._best_te_metric

    def base_update(self, clear_counter=False):
        """
        Decide whether the best model should be updated.

        :param clear_counter: if True and the validation metric did not
            improve, reset the early-stop counter ``es_c`` to zero after
            incrementing it.
        :return: whatever ``MSelTrLoss.base_update`` returns when no
            validation metric is available; otherwise True if the validation
            metric improved (the best model should be updated), else False.
        """
        observer = self.observer4msel
        if observer.metric_val is None:
            # no validation metric: delegate to the parent's selection rule
            return super().base_update(clear_counter)

        key = observer.str_metric4msel
        metric = observer.metric_val[key]

        # look the test metric up once; it is used both for the running
        # maximum and for the selected model's test value
        metric_te_current = None
        if observer.metric_te is not None:
            metric_te_current = observer.metric_te[key]
            self._best_te_metric = max(self._best_te_metric, metric_te_current)

        if metric > self._best_val_acc:  # update hat{model}
            # different from loss, the metric should be improved:
            # the bigger the better
            self._best_val_acc = metric
            self.es_c = 0  # restore counter
            if metric_te_current is not None:
                self._sel_model_te_acc = metric_te_current
            return True

        # no improvement: count towards early stopping
        self.es_c += 1
        logger = Logger.get_logger()
        logger.info(f"early stop counter: {self.es_c}")
        # log the selection metric itself (not a hard-coded 'acc' entry, which
        # may be absent or differ from the metric actually used for selection)
        logger.info(
            f"val {key}:{metric}, "
            f"best validation {key}: {self.best_val_acc}, "
            f"corresponding to test {key}: "
            f"{self.sel_model_te_acc} / {self.best_te_metric}"
        )
        if clear_counter:
            logger.info("clearing counter")
            self.es_c = 0
        return False  # do not update best model