Source code for domainlab.algos.builder_dann

"""
Builder for the Domain Adversarial Neural Network (DANN); accepts different training schemes.
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential
from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter
from domainlab.compos.nn_zoo.net_classif import ClassifDropoutReluLinear
from domainlab.compos.utils_conv_get_flat_dim import get_flat_dim
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.models.model_dann import mk_dann
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderDANN(NodeAlgoBuilder):
    """
    Algorithm builder node for DANN.

    Wires together the feature extractor, the label classifier, the domain
    discriminator, the model, the observer and the trainer.
    """

    def init_business(self, exp):
        """
        Assemble everything needed to train DANN for one experiment.

        :param exp: experiment object; its ``task`` and ``args`` attributes
            are read here.
        :return: the tuple ``(trainer, model, observer, device)``.
        """
        task = exp.task
        # Kept on the instance because reset_aux_net reads self._task.
        self._task = task
        args = exp.args
        task.get_list_domains_tr_te(args.tr_d, args.te_d)
        device = get_device(args)

        # Model selection wrapped by the observer chain.
        val_selector = MSelValPerf(max_es=args.es)
        model_selector = MSelOracleVisitor(
            val_selector, val_threshold=args.val_threshold
        )
        observer = ObVisitorCleanUp(ObVisitor(model_selector))

        # request, @FIXME, constant string
        node_getter = FeatExtractNNBuilderChainNodeGetter(
            args, arg_name_of_net="nname", arg_path_of_net="npath"
        )
        feat_builder = node_getter()

        img = task.isize
        # NOTE(review): isize is passed as (i_c, i_w, i_h) here, whereas
        # get_flat_dim below receives (i_c, i_h, i_w) -- confirm which order
        # the feature-extractor builder expects.
        net_encoder = feat_builder.init_business(
            flag_pretrain=True,
            dim_out=task.dim_y,
            remove_last_layer=False,
            args=args,
            isize=(img.i_c, img.i_w, img.i_h),
        )

        # The classifier is created before the discriminator, as in the
        # original statement order.
        flat_dim = get_flat_dim(net_encoder, img.i_c, img.i_h, img.i_w)
        net_classifier = ClassifDropoutReluLinear(flat_dim, task.dim_y)
        net_discriminator = self.reset_aux_net(net_encoder)

        model_factory = mk_dann(
            list_str_y=task.list_str_y, net_classifier=net_classifier
        )
        model = model_factory(
            list_d_tr=task.list_domain_tr,
            alpha=get_gamma_reg(args, "dann"),
            net_encoder=net_encoder,
            net_discriminator=net_discriminator,
            builder=self,
        )
        model = self.init_next_model(model, exp)

        trainer_getter = TrainerChainNodeGetter(args.trainer)
        trainer = trainer_getter(default="hyperscheduler")
        trainer.init_business(model, task, observer, device, args)
        if trainer.name == "hyperscheduler":
            # Warm-up length is expressed in batches; the scheduler is
            # configured to update per batch rather than per epoch.
            warmup_steps = trainer.num_batches * args.warmup
            trainer.set_scheduler(
                HyperSchedulerWarmupExponential,
                total_steps=warmup_steps,
                flag_update_epoch=False,
                flag_update_batch=True,
            )
        return trainer, model, observer, device

    def reset_aux_net(self, net_encoder):
        """
        Create a new auxiliary (domain discriminator) network from the task.

        Note that ``net_encoder`` can also be a method such as
        ``extract_semantic_feat``.

        Reads ``self._task``; in the code visible here that attribute is only
        assigned by ``init_business``.

        :param net_encoder: feature extractor handed to ``get_flat_dim`` to
            obtain the discriminator's input dimension.
        :return: a ``ClassifDropoutReluLinear`` with one output per entry of
            the task's ``list_domain_tr``.
        """
        task = self._task
        img = task.isize
        flat_dim = get_flat_dim(net_encoder, img.i_c, img.i_h, img.i_w)
        return ClassifDropoutReluLinear(flat_dim, len(task.list_domain_tr))