Source code for domainlab.algos.trainers.train_mldg

"""
Meta Learning Domain Generalization
"""
import copy
import random

from torch.utils.data.dataset import ConcatDataset

from domainlab.algos.trainers.a_trainer import AbstractTrainer
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.tasks.utils_task import mk_loader
from domainlab.tasks.utils_task_dset import DsetZip
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg
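
For orientation, here is a minimal, self-contained sketch of the MLDG meta-objective of Li et al. (2018), F(theta) + gamma * G(theta - inner_lr * grad F(theta)), where F is the task loss on the virtual source domains and G the task loss on the held-out virtual target domain. The names mldg_objective, inner_lr and gamma are illustrative and not part of domainlab; the trainer below realizes the inner step differently, by letting a TrainerBasic take a gradient step on a deep copy of the model.

# Illustrative sketch, not part of TrainerMLDG or the domainlab API.
import torch
from torch import nn
from torch.func import functional_call


def mldg_objective(net, x_s, y_s, x_t, y_t, inner_lr=0.01, gamma=1.0):
    """MLDG meta-objective F(theta) + gamma * G(theta - inner_lr * grad F(theta))"""
    criterion = nn.CrossEntropyLoss()
    params = dict(net.named_parameters())
    # meta-train loss F(theta) on a batch from the virtual source domains
    loss_src = criterion(functional_call(net, params, (x_s,)), y_s)
    grads = torch.autograd.grad(loss_src, tuple(params.values()), create_graph=True)
    # look-ahead parameters theta' = theta - inner_lr * grad F(theta)
    params_ahead = {
        name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)
    }
    # meta-test loss G(theta') on a batch from the held-out virtual target domain
    loss_tgt = criterion(functional_call(net, params_ahead, (x_t,)), y_t)
    return loss_src + gamma * loss_tgt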


class TrainerMLDG(AbstractTrainer):
    """
    trainer for Meta-Learning Domain Generalization (MLDG)
    """
    def before_tr(self):
        """
        check the performance of the randomly initialized weights
        """
        self.model.evaluate(self.loader_te, self.device)
        self.inner_trainer = TrainerBasic()
        self.inner_trainer.extend(self._decoratee)
        inner_model = copy.deepcopy(self.model)
        self.inner_trainer.init_business(
            inner_model,
            copy.deepcopy(self.task),
            self.observer,
            self.device,
            self.aconf,
            flag_accept=False,
        )
        self.prepare_ziped_loader()
        super().before_tr()
    def prepare_ziped_loader(self):
        """
        create virtual source and target domains: randomly hold out one training
        domain as the virtual target and concatenate the remaining domains as the
        virtual source (an illustrative, stand-alone sketch of this pairing follows
        the class definition below)
        """
        list_dsets = list(self.task.dict_dset_tr.values())
        num_domains = len(list_dsets)
        ind_target_domain = random.randrange(num_domains)
        tuple_dsets_source = tuple(
            list_dsets[ind] for ind in range(num_domains) if ind != ind_target_domain
        )
        ddset_source = ConcatDataset(tuple_dsets_source)
        ddset_target = list_dsets[ind_target_domain]
        ddset_mix = DsetZip(ddset_source, ddset_target)
        self.loader_tr_source_target = mk_loader(ddset_mix, self.aconf.bs)
    def tr_epoch(self, epoch):
        self.model.train()
        self.epo_loss_tr = 0
        self.prepare_ziped_loader()
        # s means source, t means target
        for ind_batch, (
            tensor_x_s,
            vec_y_s,
            vec_d_s,
            others_s,
            tensor_x_t,
            vec_y_t,
            vec_d_t,
            *_,
        ) in enumerate(self.loader_tr_source_target):
            tensor_x_s, vec_y_s, vec_d_s = (
                tensor_x_s.to(self.device),
                vec_y_s.to(self.device),
                vec_d_s.to(self.device),
            )
            tensor_x_t, vec_y_t, vec_d_t = (
                tensor_x_t.to(self.device),
                vec_y_t.to(self.device),
                vec_d_t.to(self.device),
            )
            self.optimizer.zero_grad()
            # update inner_model with the current weights of self.model
            self.inner_trainer.model.load_state_dict(self.model.state_dict())
            self.inner_trainer.before_epoch()  # set model to train mode
            self.inner_trainer.reset()  # force optimizer to re-initialize
            self.inner_trainer.tr_batch(
                tensor_x_s, vec_y_s, vec_d_s, others_s, ind_batch, epoch
            )
            # inner_model has now accumulated gradients G_i and holds
            # parameters theta_i - lr * G_i, where i indexes the batch
            loss_look_forward = self.inner_trainer.model.cal_task_loss(
                tensor_x_t, vec_y_t
            )
            loss_source_task = self.model.cal_task_loss(tensor_x_s, vec_y_s)
            list_source_reg_tr, list_source_mu_tr = self.cal_reg_loss(
                tensor_x_s, vec_y_s, vec_d_s, others_s
            )
            # cal_reg_loss is called on the decoratee; super()._cal_reg_loss returns
            # [], [] since MLDG's own reg loss lives on the target domain, so no
            # trainer other than the hyperparameter scheduler can decorate this one
            # unless a state pattern is introduced in the future to control the
            # source and target domain loader behavior
            source_reg_tr = self.model.list_inner_product(
                list_source_reg_tr, list_source_mu_tr
            )
            loss = (
                loss_source_task.sum()
                + source_reg_tr.sum()
                + get_gamma_reg(self.aconf, self.name) * loss_look_forward.sum()
            )
            loss.backward()
            # the optimizer only updates the parameters of self.model, not inner_model
            self.optimizer.step()
            self.epo_loss_tr += loss.detach().item()
            self.after_batch(epoch, ind_batch)
        flag_stop = self.observer.update(epoch)  # notify observer
        return flag_stop
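
The virtual source/target pairing built in prepare_ziped_loader can be illustrated with a stand-alone snippet. The class _ZipDset below is a hypothetical stand-in for domainlab's DsetZip, and the tensor datasets are dummies; the point is only that each batch yields the source fields followed by the target fields, which is the layout tr_epoch unpacks (in domainlab each side additionally carries a domain label and further fields).

# Illustrative stand-in for DsetZip / mk_loader, not domainlab's implementation.
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, TensorDataset


class _ZipDset(Dataset):
    """pair two datasets item-wise, mimicking the role of DsetZip above"""

    def __init__(self, dset_a, dset_b):
        self.dset_a, self.dset_b = dset_a, dset_b

    def __len__(self):
        # iterate over the shorter dataset so every index is valid in both
        return min(len(self.dset_a), len(self.dset_b))

    def __getitem__(self, idx):
        # concatenate the source tuple and the target tuple into one item
        return (*self.dset_a[idx], *self.dset_b[idx])


dset_d1 = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
dset_d2 = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
dset_d3 = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
ddset_source = ConcatDataset([dset_d1, dset_d2])  # virtual source domains
ddset_target = dset_d3  # randomly held-out virtual target domain
loader = DataLoader(_ZipDset(ddset_source, ddset_target), batch_size=4)
x_s, y_s, x_t, y_t = next(iter(loader))  # source batch fields, then target batch fields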
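
Finally, the per-batch logic of tr_epoch can be condensed into a plain-PyTorch sketch: take one gradient step on a throw-away copy of the model using the source batch, evaluate the updated copy on the target batch as the look-forward loss, and add that term, weighted by the gamma regularization coefficient, to the source task loss before the outer update. The helper mldg_batch_step and its arguments are hypothetical, and the sketch omits domainlab's regularization terms, the others field, and the observer and after_batch hooks.

# Hypothetical condensation of the batch step above; illustrative only.
import copy

import torch
from torch import nn


def mldg_batch_step(model, optimizer, x_s, y_s, x_t, y_t, inner_lr=0.01, gamma=1.0):
    criterion = nn.CrossEntropyLoss()
    # inner step: copy the current weights and take one SGD step on the source batch
    inner_model = copy.deepcopy(model)
    inner_opt = torch.optim.SGD(inner_model.parameters(), lr=inner_lr)
    inner_opt.zero_grad()
    criterion(inner_model(x_s), y_s).backward()
    inner_opt.step()
    # look-forward loss: evaluate the updated copy on the virtual target batch
    loss_look_forward = criterion(inner_model(x_t), y_t)
    # outer update: source task loss plus the weighted look-forward term;
    # the optimizer only updates the parameters of `model`, not those of the copy
    optimizer.zero_grad()
    loss = criterion(model(x_s), y_s) + gamma * loss_look_forward
    loss.backward()
    optimizer.step()
    return loss.detach().item()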