Source code for domainlab.algos.trainers.a_trainer

"""
Base Class for trainer
"""
import abc

import torch
from torch import optim

from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler


[docs] def mk_opt(model, aconf): """ create optimizer """ if model._decoratee is None: optimizer = optim.Adam(model.parameters(), lr=aconf.lr) else: var1 = model.parameters() var2 = model._decoratee.parameters() set_param = set(list(var1) + list(var2)) list_par = list(set_param) # optimizer = optim.Adam([var1, var2], lr=aconf.lr) # optimizer = optim.Adam([ # {'params': model.parameters()}, # {'params': model._decoratee.parameters()} # ], lr=aconf.lr) optimizer = optim.Adam(list_par, lr=aconf.lr) return optimizer
[docs] class AbstractTrainer(AbstractChainNodeHandler, metaclass=abc.ABCMeta): """ Algorithm director that controls the data flow """ @property def p_na_prefix(self): """ common prefix for Trainers """ return "Trainer"
[docs] def extend(self, trainer): """ extend current trainer with another trainer """ self._decoratee = trainer
def __init__(self, successor_node=None, extend=None): """__init__. :param successor_node: """ super().__init__(successor_node) self._model = None self._decoratee = extend self.task = None self.observer = None self.device = None self.aconf = None # self.dict_loader_tr = None self.loader_tr = None self.loader_te = None self.num_batches = None self.flag_update_hyper_per_epoch = None self.flag_update_hyper_per_batch = None self.epo_loss_tr = None self.epo_reg_loss_tr = None self.epo_task_loss_tr = None self.counter_batch = None self.hyper_scheduler = None self.optimizer = None self.exp = None # matchdg self.lambda_ctr = None self.flag_stop = None self.flag_erm = None self.tensor_ref_domain2each_domain_x = None self.tensor_ref_domain2each_domain_y = None self.base_domain_size = None self.tuple_tensor_ref_domain2each_y = None self.tuple_tensor_refdomain2each = None # mldg self.inner_trainer = None self.loader_tr_source_target = None self.flag_initialized = False # moving average self.ma_weight_previous_model_params = None self._dict_previous_para_persist = {} self._ma_iter = 0 # self.list_reg_over_task_ratio = None # MIRO self.input_tensor_shape = None @property def model(self): """ property model, which can be another trainer or model """ return self.get_model() @model.setter def model(self, model): self._model = model @property def str_metric4msel(self): """ metric for model selection """ return self.model.metric4msel @property def list_tr_domain_size(self): """ get a list of training domain size """ train_domains = self.task.list_domain_tr return [len(self.task.dict_dset_tr[key]) for key in train_domains] @property def decoratee(self): if self._decoratee is None: return self.model return self._decoratee
[docs] def init_business(self, model, task, observer, device, aconf, flag_accept=True): """ model, task, observer, device, aconf """ # Note self.decoratee can be both model and trainer, # but self._decoratee can only be trainer! if self._decoratee is not None: self._decoratee.init_business( model, task, observer, device, aconf, flag_accept ) self.model = self._decoratee else: self.model = model self.task = task self.task.init_business(trainer=self, args=aconf) self.model.list_d_tr = self.task.list_domain_tr self.observer = observer self.device = device self.aconf = aconf # self.dict_loader_tr = task.dict_loader_tr self.loader_tr = task.loader_tr self.loader_te = task.loader_te if flag_accept: self.observer.accept(self) self.model = self.model.to(device) # self.num_batches = len(self.loader_tr) self.flag_update_hyper_per_epoch = False self.flag_update_hyper_per_batch = False self.epo_loss_tr = None self.hyper_scheduler = None self.reset() self.flag_initialized = True
[docs] def reset(self): """ make a new optimizer to clear internal state """ self.optimizer = mk_opt(self.model, self.aconf)
[docs] @abc.abstractmethod def tr_epoch(self, epoch): """ :param epoch: """
[docs] def before_batch(self, epoch, ind_batch): """ :param epoch: :param ind_batch: """ return
[docs] def after_batch(self, epoch, ind_batch): """ :param epoch: :param ind_batch: """ return
[docs] def before_tr(self): """ before training, probe model performance """ self.cal_reg_loss_over_task_loss_ratio()
[docs] def cal_reg_loss_over_task_loss_ratio(self): """ estimate the scale of each loss term, match each loss term to the major loss via a ratio, this ratio will be multiplied with multiplier """ list_accum_reg_loss = [] loss_task_agg = 0 for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( self.loader_tr ): self.input_tensor_shape = tensor_x.shape if ind_batch >= self.aconf.nb4reg_over_task_ratio: return tensor_x, tensor_y, tensor_d = ( tensor_x.to(self.device), tensor_y.to(self.device), tensor_d.to(self.device), ) list_reg_loss_tensor, _ = \ self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) list_reg_loss_tensor = [torch.sum(tensor).detach().item() for tensor in list_reg_loss_tensor] if ind_batch == 0: list_accum_reg_loss = list_reg_loss_tensor else: list_accum_reg_loss = [reg_loss_accum_tensor + reg_loss_tensor for reg_loss_accum_tensor, reg_loss_tensor in zip(list_accum_reg_loss, list_reg_loss_tensor)] tensor_loss_task = self.model.cal_task_loss(tensor_x, tensor_y) tensor_loss_task = torch.sum(tensor_loss_task).detach().item() loss_task_agg += tensor_loss_task self.list_reg_over_task_ratio = [reg_loss / loss_task_agg for reg_loss in list_accum_reg_loss]
[docs] def post_tr(self): """ after training """ self.observer.after_all()
@property def name(self): """ get the name of the algorithm """ na_prefix = self.p_na_prefix len_prefix = len(na_prefix) na_class = type(self).__name__ if na_class[:len_prefix] != na_prefix: raise RuntimeError( "Trainer builder node class must start with ", na_prefix, "the current class is named: ", na_class, ) return type(self).__name__[len_prefix:].lower()
[docs] def is_myjob(self, request): """ :param request: string """ return request == self.name
[docs] def get_model(self): """ recursively get the "real" model from trainer """ if "trainer" not in str(type(self._model)).lower(): return self._model return self._model.get_model()
[docs] def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ decorate trainer regularization loss combine losses of current trainer with self._model.cal_reg_loss, which can be either a trainer or a model """ list_reg_loss_model_tensor, list_mu_model = \ self.decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) assert len(list_reg_loss_model_tensor) == len(list_mu_model) list_reg_loss_trainer_tensor, list_mu_trainer = self._cal_reg_loss( tensor_x, tensor_y, tensor_d, others ) assert len(list_reg_loss_trainer_tensor) == len(list_mu_trainer) # extend the length of list: extend number of regularization loss # tensor: the element of list is tensor list_loss_tensor = list_reg_loss_model_tensor + \ list_reg_loss_trainer_tensor list_mu = list_mu_model + list_mu_trainer return list_loss_tensor, list_mu
def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ interface for each trainer to implement """ return [], []
[docs] def dset_decoration_args_algo(self, args, ddset): """ decorate dataset to get extra entries in load item, for instance, jigen need permutation index this parent class function delegate decoration to its decoratee """ if self._decoratee is not None: return self._decoratee.dset_decoration_args_algo(args, ddset) return ddset
[docs] def print_parameters(self): """ Function to print all parameters of the object. Can be used to print the parameters of any child class """ params = vars(self) print(f"Parameters of {type(self).__name__}: {params}")