Source code for domainlab.algos.builder_hduva
"""
build hduva model, get trainer from cmd arguments
"""
from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter
from domainlab.compos.pcr.request import RequestVAEBuilderCHW
from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter
from domainlab.models.model_hduva import mk_hduva
from domainlab.utils.utils_cuda import get_device
[docs]
class NodeAlgoBuilderHDUVA(NodeAlgoBuilder):
    """
    Builder node for the HDUVA algorithm.

    Assembles the HDUVA model, the model-selection observer and the trainer
    from an experiment's task and command-line arguments.
    """

    def init_business(self, exp):
        """
        Build the HDUVA model and the objects needed to train it.

        :param exp: experiment object; its ``task`` and ``args`` attributes
            are read here.
        :return: tuple ``(trainer, model, observer, device)``. The trainer
            has already had ``init_business`` called on it with the model,
            task, observer, device and args.
        """
        task = exp.task
        args = exp.args
        # The return value is discarded; presumably this is called for its
        # side effect of recording the training/test domains on ``task``
        # -- TODO confirm against the task implementation.
        task.get_list_domains_tr_te(args.tr_d, args.te_d)
        # Network-builder request sized from the task's input size
        # (``task.isize`` c, h, w) plus the command-line args.
        request = RequestVAEBuilderCHW(
            task.isize.c, task.isize.h, task.isize.w, args
        )
        device = get_device(args)
        node = VAEChainNodeGetter(request, args.topic_dim)()
        model = mk_hduva(list_str_y=task.list_str_y)(
            node,
            zd_dim=args.zd_dim,
            zy_dim=args.zy_dim,
            zx_dim=args.zx_dim,
            device=device,
            topic_dim=args.topic_dim,
            gamma_d=args.gamma_d,
            gamma_y=args.gamma_y,
            beta_t=args.beta_t,
            beta_x=args.beta_x,
            beta_y=args.beta_y,
            beta_d=args.beta_d,
        )
        # Hand the model to the rest of the builder chain; what
        # ``init_next_model`` does is defined in the base class (not shown).
        model = self.init_next_model(model, exp)
        # Model selection: a validation-performance selector (``max_es`` is
        # presumably the early-stopping patience) wrapped by the oracle
        # visitor, which also receives ``args.val_threshold``.
        model_sel = MSelOracleVisitor(
            MSelValPerf(max_es=args.es), val_threshold=args.val_threshold
        )
        observer = ObVisitorCleanUp(ObVisitor(model_sel))
        # "hyperscheduler" is passed as the default trainer name, presumably
        # used when ``args.trainer`` is not given -- confirm in
        # TrainerChainNodeGetter.
        trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler")
        trainer.init_business(model, task, observer, device, args)
        return trainer, model, observer, device