Source code for domainlab.algos.trainers.train_irm

"""
use random start to generate adversarial images
"""
import torch
from torch import autograd
from torch.nn import functional as F
from domainlab.algos.trainers.train_basic import TrainerBasic


class TrainerIRM(TrainerBasic):
    """
    IRMv1 splits a minibatch into two halves and forms an unbiased estimate
    of the squared gradient norm via the inner product of

    $$\\nabla_{w|w=1} \\ell(w \\cdot \\Phi(X^{e, i}), Y^{e, i})$$

    of dimension dim(Grad) with

    $$\\nabla_{w|w=1} \\ell(w \\cdot \\Phi(X^{e, j}), Y^{e, j})$$

    of dimension dim(Grad).

    For more details, see section 3.2 and Appendix D of:
    Arjovsky et al., "Invariant Risk Minimization."
    (A standalone sketch of this estimator follows the listing below.)
    """
    def tr_epoch(self, epoch):
        # draw one minibatch per training domain in lockstep
        list_loaders = list(self.dict_loader_tr.values())
        loaders_zip = zip(*list_loaders)
        self.model.train()
        self.epo_loss_tr = 0
        for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip):
            self.optimizer.zero_grad()
            list_domain_loss_erm = []
            list_domain_reg = []
            for batch_domain_e in tuple_data_domains_batch:
                tensor_x, tensor_y, tensor_d, *others = batch_domain_e
                tensor_x, tensor_y, tensor_d = \
                    tensor_x.to(self.device), tensor_y.to(self.device), \
                    tensor_d.to(self.device)
                # per-domain ERM (task) loss
                list_domain_loss_erm.append(
                    self.model.cal_task_loss(tensor_x, tensor_y))
                # per-domain IRM gradient penalty
                list_1ele_loss_irm, _ = \
                    self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
                list_domain_reg += list_1ele_loss_irm
            # total loss: sum of ERM losses plus weighted sum of penalties
            loss = torch.sum(torch.stack(list_domain_loss_erm)) + \
                self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg))
            loss.backward()
            self.optimizer.step()
            self.epo_loss_tr += loss.detach().item()
            self.after_batch(epoch, ind_batch)
        flag_stop = self.observer.update(epoch)  # notify observer
        return flag_stop
    def _cal_phi(self, tensor_x):
        logits = self.model.cal_logit_y(tensor_x)
        return logits

    def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
        """
        Let the trainer behave like a model, so that other trainers can
        use it.
        """
        _ = tensor_d
        _ = others
        y = tensor_y
        phi = self._cal_phi(tensor_x)
        # dummy scalar classifier fixed at 1.0, used only for its gradient
        dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_()
        # split the minibatch into halves (even/odd indices)
        loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2])
        loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2])
        grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0]
        grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0]
        # inner product of the two half-batch gradients: unbiased estimate
        # of the squared gradient norm at w = 1
        loss_irm_scalar = torch.sum(grad_1 * grad_2)  # scalar
        loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0])
        return [loss_irm_tensor], [self.aconf.gamma_reg]
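
As a quick illustration of the estimator described in the class docstring, the following self-contained sketch (toy shapes and values, not domainlab API) computes the inner product of two half-batch gradients with respect to a dummy scalar classifier fixed at 1.0:

import torch
from torch import autograd
from torch.nn import functional as F

torch.manual_seed(0)

# Toy stand-ins for the featurizer output Phi(X) and labels Y of one domain;
# shapes and values are made up purely for illustration.
phi = torch.randn(8, 3)                 # 8 samples, 3 classes
y = torch.randint(0, 3, (8,))

# Dummy scalar classifier w fixed at 1.0; we only need its gradient.
dummy_w = torch.tensor(1.).requires_grad_()

# Split the minibatch into two halves (even/odd indices).
loss_1 = F.cross_entropy(phi[::2] * dummy_w, y[::2])
loss_2 = F.cross_entropy(phi[1::2] * dummy_w, y[1::2])
grad_1 = autograd.grad(loss_1, [dummy_w], create_graph=True)[0]
grad_2 = autograd.grad(loss_2, [dummy_w], create_graph=True)[0]

# The two half-batch gradients are independent estimates of the same
# expected gradient, so their inner product is an unbiased estimate of
# the squared gradient norm at w = 1.
penalty = torch.sum(grad_1 * grad_2)
print(penalty.item())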
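
Putting the pieces together, here is a hypothetical end-to-end toy of one IRMv1 update over two domains, mirroring tr_epoch and _cal_reg_loss above; the model, optimizer, batches, and gamma_reg value are made-up stand-ins, not domainlab objects:

import torch
from torch import autograd, nn
from torch.nn import functional as F

def irm_penalty(logits, y):
    """Inner product of half-batch gradients w.r.t. a dummy scale at 1.0."""
    dummy_w = torch.tensor(1.).to(logits.device).requires_grad_()
    loss_1 = F.cross_entropy(logits[::2] * dummy_w, y[::2])
    loss_2 = F.cross_entropy(logits[1::2] * dummy_w, y[1::2])
    grad_1 = autograd.grad(loss_1, [dummy_w], create_graph=True)[0]
    grad_2 = autograd.grad(loss_2, [dummy_w], create_graph=True)[0]
    return torch.sum(grad_1 * grad_2)

torch.manual_seed(0)
model = nn.Linear(4, 3)          # hypothetical stand-in for the full network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gamma_reg = 1.0                  # penalty weight, cf. self.aconf.gamma_reg

# One toy minibatch per "domain"; in the trainer these come from
# zip(*self.dict_loader_tr.values()).
batches = [(torch.randn(8, 4), torch.randint(0, 3, (8,))),
           (torch.randn(8, 4), torch.randint(0, 3, (8,)))]

optimizer.zero_grad()
loss = torch.tensor(0.)
for tensor_x, tensor_y in batches:
    logits = model(tensor_x)
    loss = loss + F.cross_entropy(logits, tensor_y) \
                + gamma_reg * irm_penalty(logits, tensor_y)
loss.backward()
optimizer.step()

Because create_graph=True keeps the penalty differentiable with respect to the model parameters, loss.backward() propagates through both the ERM terms and the gradient-penalty terms.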