Source code for domainlab.models.model_dann

"""
construct feature extractor, task neural network (e.g. classification) and domain classification
network
"""
from torch.nn import functional as F

from domainlab import g_str_cross_entropy_agg
from domainlab.compos.nn_zoo.net_adversarial import AutoGradFunReverseMultiply
from domainlab.models.a_model_classif import AModelClassif


def mk_dann(parent_class=AModelClassif, **kwargs):
    """
    Build a domain-adversarial neural network (DANN) model class.

    Details:
        The model is trained to solve two tasks:

        1. Standard image classification.
        2. Domain classification.

        The feature extractor is trained adversarially: it should minimize
        the loss of the image classifier while maximizing the loss of the
        domain classifier.

        For more details, see:
        Ganin, Yaroslav, et al. "Domain-adversarial training of neural networks."
        The journal of machine learning research 17.1 (2016): 2096-2030.

    Args:
        parent_class (AModel, optional): Class object determining the task
            type. Defaults to AModelClassif.
        **kwargs: captured by the returned class and forwarded unchanged to
            ``parent_class.__init__`` each time a ``ModelDAN`` is created.

    Returns:
        ModelDAN: model class inheriting from ``parent_class``

    Input Parameters:
        list_str_y: list of labels,
        list_d_tr: list of training domains,
        alpha: total_loss = task_loss + $$\\alpha$$ * domain_classification_loss,
        net_encoder: neural network to extract the features
        (input: training data),
        net_classifier: neural network
        (input: output of net_encoder; output: label prediction),
        net_discriminator: neural network
        (input: output of net_encoder; output: prediction of training domain)

        NOTE(review): ``list_str_y`` and ``net_classifier`` are not part of
        ``ModelDAN.__init__``; presumably they are consumed by the parent
        class through ``**kwargs`` -- confirm against ``parent_class``.

    Usage:
        For a concrete example, see:
        https://github.com/marrlab/DomainLab/blob/master/tests/test_mk_exp_dann.py
    """

    class ModelDAN(parent_class):
        """
        Model class produced by mk_dann(); see that function's documentation.
        """

        def __init__(
            self,
            list_d_tr,
            alpha,
            net_encoder,
            net_discriminator,
            builder=None,
        ):
            """
            See the documentation of the enclosing mk_dann() function.
            """
            # The parent is configured only from the keyword arguments that
            # were handed to mk_dann(), not from this constructor's arguments.
            super().__init__(**kwargs)
            # Assignment order is kept as-is: the parent class may register
            # attributes in the order in which they are set.
            self.list_d_tr = list_d_tr
            self.alpha = alpha
            self._net_invar_feat = net_encoder
            self.net_discriminator = net_discriminator
            self.builder = builder

        def reset_aux_net(self):
            """
            Reset the auxiliary neural network, i.e. the domain classifier.

            Does nothing when no builder was supplied at construction time.
            """
            if self.builder is not None:
                self.net_discriminator = self.builder.reset_aux_net(
                    self.extract_semantic_feat
                )

        def hyper_update(self, epoch, fun_scheduler):
            """
            Refresh alpha from the hyperparameter scheduler.

            :param epoch: value passed to the scheduler call
            :param fun_scheduler: the hyperparameter scheduler object; calling
                it with ``epoch`` yields a mapping that is looked up under the
                key ``self.name + "_alpha"``
            """
            # fun_scheduler(epoch) goes through the scheduler's __call__
            self.alpha = fun_scheduler(epoch)[self.name + "_alpha"]

        def hyper_init(self, functor_scheduler):
            """
            Create the hyperparameter scheduler for this model.

            :param functor_scheduler: callable invoked with ``trainer=None``
                and the current alpha as keyword ``self.name + "_alpha"``
            :return: whatever ``functor_scheduler`` returns
            """
            initial_values = {self.name + "_alpha": self.alpha}
            return functor_scheduler(trainer=None, **initial_values)

        def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others):
            """
            Compute the domain-classification regularization loss.

            :param tensor_x: input batch given to the feature extractor
            :param tensor_y: unused
            :param tensor_d: domain labels; the index of the row-wise maximum
                is used as target (presumably one-hot encoded -- TODO confirm)
            :param others: unused
            :return: a one-element list with the domain loss and a
                one-element list with its multiplier ``self.alpha``
            """
            del tensor_y, others  # intentionally unused by this model
            semantic_feat = self.extract_semantic_feat(tensor_x)
            # Features go through AutoGradFunReverseMultiply together with
            # alpha before reaching the discriminator.
            reversed_feat = AutoGradFunReverseMultiply.apply(
                semantic_feat, self.alpha
            )
            domain_logits = self.net_discriminator(reversed_feat)
            # max(dim=1) yields (values, indices); only the indices are needed
            domain_target = tensor_d.max(dim=1)[1]
            loss_domain = F.cross_entropy(
                domain_logits, domain_target, reduction=g_str_cross_entropy_agg
            )
            return [loss_domain], [self.alpha]

    return ModelDAN