Source code for domainlab.algos.builder_jigen1

"""
builder for JiGen
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential
from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler
from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter
from domainlab.compos.nn_zoo.net_classif import ClassifDropoutReluLinear
from domainlab.compos.utils_conv_get_flat_dim import get_flat_dim
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches
from domainlab.models.model_jigen import mk_jigen
from domainlab.utils.utils_cuda import get_device
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg


class NodeAlgoBuilderJiGen(NodeAlgoBuilder):
    """
    Algorithm-builder node that wires up trainer, model and observer for JiGen.
    """

    def init_business(self, exp):
        """
        Assemble the objects needed to train a JiGen model.

        Steps, in order: pick the device, build the model-selection/observer
        chain, build the feature extractor and the two classification heads,
        create the JiGen model, then create and initialise the trainer.

        :param exp: experiment object; only ``exp.task`` and ``exp.args`` are
            read here, and ``exp`` itself is forwarded to ``init_next_model``
        :return: the tuple ``(trainer, model, observer, device)``
        """
        task = exp.task
        args = exp.args
        device = get_device(args)

        # Model selection: validation-performance selector wrapped by the
        # oracle visitor, then observed through the clean-up decorator.
        selector_val = MSelValPerf(max_es=args.es)
        selector = MSelOracleVisitor(
            msel=selector_val, val_threshold=args.val_threshold
        )
        observer = ObVisitorCleanUp(ObVisitor(selector))

        # request, @FIXME, constant string
        getter_feat_extract = FeatExtractNNBuilderChainNodeGetter(
            args, arg_name_of_net="nname", arg_path_of_net="npath"
        )
        feat_builder = getter_feat_extract()

        # NOTE(review): the size tuple here is ordered (channel, width, height)
        # while get_flat_dim below receives (channel, height, width) -- confirm
        # that both callees expect these respective orders.
        isize_for_builder = (task.isize.i_c, task.isize.i_w, task.isize.i_h)
        net_encoder = feat_builder.init_business(
            flag_pretrain=True,
            dim_out=task.dim_y,
            remove_last_layer=False,
            args=args,
            isize=isize_for_builder,
        )
        dim_feat = get_flat_dim(
            net_encoder, task.isize.i_c, task.isize.i_h, task.isize.i_w
        )

        # Head predicting the task label.
        head_label = ClassifDropoutReluLinear(dim_feat, task.dim_y)
        # @FIXME: this seems to be the only difference w.r.t. builder_dann
        # Head predicting the permutation index; sized args.nperm + 1
        # (presumably one extra class for the un-permuted image -- TODO confirm).
        head_permutation = ClassifDropoutReluLinear(dim_feat, args.nperm + 1)

        # mk_jigen returns a callable that is then invoked with the remaining
        # JiGen-specific components to obtain the model.
        make_model = mk_jigen(
            list_str_y=task.list_str_y, net_classifier=head_label
        )
        coeff_reg = get_gamma_reg(args, 'jigen')
        model = make_model(
            coeff_reg=coeff_reg,
            net_encoder=net_encoder,
            net_classifier_permutation=head_permutation,
            n_perm=args.nperm,
            prob_permutation=args.pperm,
        )
        model = self.init_next_model(model, exp)

        getter_trainer = TrainerChainNodeGetter(args.trainer)
        trainer = getter_trainer(default="hyperscheduler")
        trainer.init_business(model, task, observer, device, args)

        # Only the hyper-scheduler trainer gets a warm-up schedule; it is
        # configured to update per batch rather than per epoch.
        if isinstance(trainer, TrainerHyperScheduler):
            steps_warmup = trainer.num_batches * args.warmup
            trainer.set_scheduler(
                HyperSchedulerWarmupExponential,
                total_steps=steps_warmup,
                flag_update_epoch=False,
                flag_update_batch=True,
            )

        return trainer, model, observer, device