Source code for domainlab.algos.observers.b_obvisitor

"""
observer and visitor pattern, responsible train, validation, test
dispatch performance evaluation to model,
dispatch model selection to model selection object
"""
import os

from domainlab.algos.observers.a_observer import AObVisitor
from domainlab.tasks.task_folder_mk import NodeTaskFolderClassNaMismatch
from domainlab.tasks.task_pathlist import NodeTaskPathListDummy
from domainlab.utils.logger import Logger
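
# Usage sketch (illustrative only, not part of this module): a trainer
# typically wires the observer in as follows; ``trainer`` and ``msel`` are
# hypothetical stand-ins for a concrete trainer and a concrete model
# selection object.
#
#     msel = SomeModelSel()            # hypothetical model selection object
#     observer = ObVisitor(msel)
#     observer.accept(trainer)         # visitor handshake with the trainer
#     for epoch in range(1, trainer.aconf.epos + 1):
#         if observer.update(epoch):   # True once early stopping triggers
#             break
#     observer.after_all()             # evaluate the persisted model
#     observer.clean_up()              # optionally delete saved model files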


class ObVisitor(AObVisitor):
    """
    Observer + Visitor pattern for model selection
    """

    def __init__(self, model_sel):
        """
        :param model_sel: model selection object the observer dispatches to
        """
        super().__init__()
        self.host_trainer = None
        self.model_sel = model_sel
        self.epo = None
        self.metric_te = None
        self.metric_val = None
        self.perf_metric = None

    @property
    def str_metric4msel(self):
        """
        string representing the metric used for persisting models on the disk
        """
        return self.host_trainer.str_metric4msel
    def update(self, epoch):
        logger = Logger.get_logger()
        logger.info(f"epoch: {epoch}")
        self.epo = epoch
        if epoch % self.epo_te == 0:
            logger.info("---- Training Domain: ")
            self.host_trainer.model.cal_perf_metric(self.loader_tr, self.device)
            if self.loader_val is not None:
                logger.info("---- Validation: ")
                self.metric_val = self.host_trainer.model.cal_perf_metric(
                    self.loader_val, self.device
                )
            if self.loader_te is not None:
                logger.info("---- Test Domain (oracle): ")
                metric_te = self.host_trainer.model.cal_perf_metric(
                    self.loader_te, self.device
                )
                self.metric_te = metric_te
        if self.model_sel.update(epoch):
            logger.info("better model found")
            self.host_trainer.model.save()
            logger.info("persisted")
        # the test metric may not exist yet if the test loader has not
        # been evaluated at this epoch
        acc = self.metric_te.get("acc") if self.metric_te is not None else None
        flag_stop = self.model_sel.if_stop(acc)
        flag_enough = epoch >= self.host_trainer.aconf.epos_min
        return flag_stop and flag_enough
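    # note on update(): training stops only when the model selection
    # criterion signals stop AND the minimum number of epochs
    # (aconf.epos_min) has been reached, e.g. with epos_min=5 a stop
    # signal at epoch 3 is ignored, so brief early plateaus do not
    # abort training.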
    def accept(self, trainer):
        """
        accept invitation as a visitor
        """
        self.host_trainer = trainer
        self.model_sel.accept(trainer, self)
        self.set_task(trainer.task, args=trainer.aconf, device=trainer.device)
        self.perf_metric = self.host_trainer.model.create_perf_obj(self.task)
    def after_all(self):
        """
        After training is done
        """
        model_ld = None
        try:
            model_ld = self.host_trainer.model.load()
        except FileNotFoundError as err:
            # we only catch FileNotFoundError here; any other exception
            # propagates and terminates the python script
            # this can happen if the loss keeps increasing and a model
            # never gets selected
            logger = Logger.get_logger()
            logger.warning(err)
            logger.warning(
                "this error can occur if the model selection criterion "
                "keeps worsening: the model never gets persisted and "
                "no performance metric is reported"
            )
            return
        model_ld = model_ld.to(self.device)
        model_ld.eval()
        logger = Logger.get_logger()
        logger.info("persisted model performance metric: \n")
        metric_te = model_ld.cal_perf_metric(self.loader_te, self.device)
        dict_2add = self.cal_oracle_perf()
        if dict_2add is not None:
            metric_te.update(dict_2add)
        else:
            metric_te.update({"acc_oracle": -1})
        if hasattr(self, "model_sel"):
            metric_te.update({"acc_val": self.model_sel.best_val_acc})
            metric_te.update(
                {"model_selection_epoch": self.model_sel.model_selection_epoch}
            )
        else:
            metric_te.update({"acc_val": -1})
            metric_te.update({"model_selection_epoch": -1})
        self.dump_prediction(model_ld, metric_te)
        # save metric to one line in csv result file
        self.host_trainer.model.visitor(metric_te)
    # prediction dump of the test domain is essential
    # to verify the prediction results
    def cal_oracle_perf(self):
        """
        calculate oracle performance
        """
        try:
            # @FIXME: the name "oracle" is a strong dependency
            model_or = self.host_trainer.model.load("oracle")
            model_or = model_or.to(self.device)
            model_or.eval()
        except FileNotFoundError:
            return {"acc_oracle": -1}
        logger = Logger.get_logger()
        logger.info("oracle model performance metric: \n")
        metric_te = model_or.cal_perf_metric(self.loader_te, self.device)
        return {"acc_oracle": metric_te["acc"]}
    def dump_prediction(self, model_ld, metric_te):
        """
        given the test domain loader, use the loaded model model_ld
        to predict each instance
        """
        flag_task_folder = isinstance(
            self.host_trainer.task, NodeTaskFolderClassNaMismatch
        )
        flag_task_path_list = isinstance(
            self.host_trainer.task, NodeTaskPathListDummy
        )
        if flag_task_folder or flag_task_path_list:
            fname4model = (
                self.host_trainer.model.visitor.model_path
            )  # pylint: disable=E1101
            file_prefix = os.path.splitext(fname4model)[0]  # remove ".model"
            dir4preds = os.path.join(self.host_trainer.aconf.out, "saved_predicts")
            if not os.path.exists(dir4preds):
                os.mkdir(dir4preds)
            file_prefix = os.path.join(dir4preds, os.path.basename(file_prefix))
            file_name = file_prefix + "_instance_wise_predictions.txt"
            model_ld.pred2file(
                self.loader_te, self.device, filename=file_name, metric_te=metric_te
            )
    def clean_up(self):
        """
        to be called by a decorator
        """
        if not self.keep_model:
            self.host_trainer.model.visitor.clean_up()
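
# Minimal sketch of the kind of decorator clean_up() above refers to; this
# is an assumption for illustration, the decorator actually used by
# domainlab may differ:
#
#     def ensure_clean_up(func):
#         """run func, then always clean up persisted model files"""
#         def wrapper(observer, *args, **kwargs):
#             try:
#                 return func(observer, *args, **kwargs)
#             finally:
#                 observer.clean_up()
#         return wrapper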