# Source code for domainlab.algos.builder_erm

"""
builder for erm
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter
from domainlab.algos.utils import split_net_feat_last
from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter
from domainlab.models.model_erm import mk_erm
from domainlab.utils.utils_cuda import get_device


class NodeAlgoBuilderERM(NodeAlgoBuilder):
    """
    Algorithm builder node for ERM.
    """

    def init_business(self, exp):
        """
        Assemble the pieces needed to run ERM for the given experiment.

        :param exp: experiment object; its ``task`` and ``args`` attributes
            are read here, and it is forwarded to ``init_next_model``
        :return: the tuple ``(trainer, model, observer, device)``
        """
        task = exp.task
        args = exp.args
        device = get_device(args)

        # Model selection: the validation-performance selector is wrapped
        # by the oracle visitor, and the observer is built on top of it.
        msel_val = MSelValPerf(max_es=args.es)
        msel_oracle = MSelOracleVisitor(
            msel_val, val_threshold=args.val_threshold
        )
        observer = ObVisitor(msel_oracle)

        # FIXME: the argument names "nname" and "npath" are hard-coded
        # constant strings.
        builder_getter = FeatExtractNNBuilderChainNodeGetter(
            args,
            arg_name_of_net="nname",
            arg_path_of_net="npath",
        )
        net_builder = builder_getter()

        dim_out = task.dim_y
        img_size = task.isize
        image_shape = (img_size.i_c, img_size.i_h, img_size.i_w)
        net = net_builder.init_business(
            flag_pretrain=True,
            dim_out=dim_out,
            remove_last_layer=False,
            args=args,
            isize=image_shape,
        )

        # Both halves of the split are discarded; the model below receives
        # the complete network through ``net`` only.
        _, _ = split_net_feat_last(net)

        make_model = mk_erm(list_str_y=task.list_str_y)
        model = make_model(net=net)
        model = self.init_next_model(model, exp)

        # The trainer is returned without init_business being called on it
        # here.
        trainer_getter = TrainerChainNodeGetter(args.trainer)
        trainer = trainer_getter(default="basic")

        return trainer, model, observer, device