Source code for domainlab.algos.trainers.train_dial

"""
use random start to generate adversarial images
"""
import torch

from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class TrainerDIAL(TrainerBasic):
    """
    Trainer Domain Invariant Adversarial Learning
    """
    def gen_adversarial(self, device, img_natural, vec_y):
        """
        Use naive trimming to optimize the image in the direction of the
        adversarial gradient; this is not necessarily constraint-optimal
        due to nonlinearity, as the constraint epsilon is only enforced
        ad hoc.
        """
        # ensure the adversarial image is not part of the computational graph
        steps_perturb = self.aconf.dial_steps_perturb
        scale = self.aconf.dial_noise_scale
        step_size = self.aconf.dial_lr
        epsilon = self.aconf.dial_epsilon
        img_adv_ini = img_natural.detach()
        # random start: initialize the perturbation with Gaussian noise
        img_adv_ini = (
            img_adv_ini + scale * torch.randn(img_natural.shape).to(device).detach()
        )
        img_adv = img_adv_ini
        for _ in range(steps_perturb):
            img_adv.requires_grad_()
            loss_gen_adv = self.model.cal_loss_gen_adv(img_natural, img_adv, vec_y)
            grad = torch.autograd.grad(loss_gen_adv, [img_adv])[0]
            # instead of gradient descent, we do gradient ascent here
            img_adv = img_adv_ini.detach() + step_size * torch.sign(grad.detach())
            # project back onto the epsilon ball around the natural image
            # (see the standalone sketch at the end of this file)
            img_adv = torch.min(
                torch.max(img_adv, img_natural - epsilon), img_natural + epsilon
            )
            img_adv = torch.clamp(img_adv, 0.0, 1.0)
        return img_adv
    def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
        """
        Let the trainer behave like a model, so that other trainers can use it.
        """
        _ = tensor_d
        _ = others
        tensor_x_adv = self.gen_adversarial(self.device, tensor_x, tensor_y)
        # detach so no gradient flows back through the adversarial generation
        tensor_x_batch_adv_no_grad = tensor_x_adv.detach()
        loss_dial = self.model.cal_task_loss(tensor_x_batch_adv_no_grad, tensor_y)
        return [loss_dial], [get_gamma_reg(self.aconf, self.name)]
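
# --- Editorial sketch, not part of the original module ---------------------
# The torch.min / torch.max pair in gen_adversarial projects the perturbed
# image back onto the L-infinity ball of radius epsilon around the natural
# image. Below is a minimal, self-contained illustration of that projection
# step; all names are illustrative, not DomainLab API.
def _demo_project_linf():
    img_natural = torch.rand(1, 3, 8, 8)  # toy "natural" image in [0, 1]
    # a candidate perturbation far larger than the budget allows
    img_adv = img_natural + 0.5 * torch.randn_like(img_natural)
    epsilon = 0.1
    # trim every pixel back into [img_natural - eps, img_natural + eps]
    img_adv = torch.min(
        torch.max(img_adv, img_natural - epsilon), img_natural + epsilon
    )
    img_adv = torch.clamp(img_adv, 0.0, 1.0)  # stay in the valid pixel range
    assert (img_adv - img_natural).abs().max() <= epsilon + 1e-6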
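
# --- Editorial sketch, not part of the original module ---------------------
# _cal_reg_loss returns a list of regularization losses plus a matching list
# of multipliers; the surrounding trainer is expected to fold them into the
# total loss as a weighted sum. A toy illustration of that assumed contract,
# using stand-in tensors rather than DomainLab calls:
def _demo_combine_losses():
    loss_task = torch.tensor(1.2, requires_grad=True)  # stand-in task loss
    list_loss_reg = [torch.tensor(0.4, requires_grad=True)]
    list_multiplier = [0.1]  # plays the role of get_gamma_reg(...)
    loss_total = loss_task + sum(
        mul * loss for loss, mul in zip(list_loss_reg, list_multiplier)
    )
    loss_total.backward()  # gradients reach both the task and the reg term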