Source code for domainlab.tasks.b_task

"""
Use dictionaries to create train and test domain split
"""
from torch.utils.data.dataset import ConcatDataset

from domainlab.tasks.a_task import NodeTaskDG
from domainlab.tasks.utils_task import DsetDomainVecDecorator, mk_loader, mk_onehot


class NodeTaskDict(NodeTaskDG):
    """
    Task node that keeps its train / validation / test datasets in
    dictionaries keyed by domain name.
    """

    def get_dset_by_domain(self, args, na_domain, split=False):
        """
        Return the dataset(s) of the domain named ``na_domain``.

        Each domain corresponds to one dataset. This base implementation
        always raises, so child classes must override it. ``init_business``
        unpacks the result as ``(train, validation)`` for training domains
        and keeps only the first element for test domains.

        :param args: command line arguments (forwarded unchanged)
        :param na_domain: name of the domain to load
        :param split: forwarded flag; ``init_business`` passes ``args.split``
            for training domains and ``False`` for test domains
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    # it is safe for each subclass to implement this
    def decorate_dset(self, model, args):
        """
        Hook for dispatching the re-organization of the data flow to the
        model. The base implementation does nothing and returns ``None``.
        """

    def init_business(self, args, trainer=None):
        """
        Build the per-domain dataset dictionaries and the mixed loaders.

        Populates ``dict_dset_tr``, ``dict_dset_val``, ``dict_dset_te`` and
        ``dict_loader_tr`` and sets ``_loader_tr``, ``_loader_val`` and
        ``_loader_te``.

        NOTE(review): ``self.dict_loader_tr`` is only updated here, never
        created, so it must already exist on the instance -- presumably it
        is initialised by the parent class; confirm.

        :param args: command line arguments; the attributes read are
            ``tr_d``, ``te_d``, ``split``, ``bs`` and ``shuffling_off``
        :param trainer: optional object; if it, or its ``model`` attribute,
            provides ``dset_decoration_args_algo``, that callable is applied
            to every training and validation dataset
        """
        domains_train, domains_test = self.get_list_domains_tr_te(
            args.tr_d, args.te_d
        )
        self.dict_dset_tr = {}
        self.dict_dset_val = {}
        num_domains = len(domains_train)

        for idx, name in enumerate(domains_train):
            raw_tr, raw_val = self.get_dset_by_domain(args, name, split=args.split)

            # one-hot domain label, used e.g. by diva and dann
            onehot = mk_onehot(num_domains, idx)
            wrapped_tr = DsetDomainVecDecorator(raw_tr, onehot, name)
            wrapped_val = DsetDomainVecDecorator(raw_val, onehot, name)

            if trainer is not None:
                # the trainer decorates first, then the model it carries
                if hasattr(trainer, "dset_decoration_args_algo"):
                    wrapped_tr = trainer.dset_decoration_args_algo(args, wrapped_tr)
                    wrapped_val = trainer.dset_decoration_args_algo(args, wrapped_val)
                if trainer.model is not None and hasattr(
                    trainer.model, "dset_decoration_args_algo"
                ):
                    wrapped_tr = trainer.model.dset_decoration_args_algo(
                        args, wrapped_tr
                    )
                    wrapped_val = trainer.model.dset_decoration_args_algo(
                        args, wrapped_val
                    )

            self.dict_dset_tr[name] = wrapped_tr
            self.dict_loader_tr.update({name: mk_loader(wrapped_tr, args.bs)})
            self.dict_dset_val[name] = wrapped_val

        # training loader over the concatenation of all training domains
        mixed_tr = ConcatDataset(tuple(self.dict_dset_tr.values()))
        # args.shuffling_off defaults to False, so shuffling is on by default
        shuffle_tr = not args.shuffling_off
        self._loader_tr = mk_loader(mixed_tr, args.bs, shuffle=shuffle_tr)

        # validation loader: fixed order, incomplete last batch is kept
        mixed_val = ConcatDataset(tuple(self.dict_dset_val.values()))
        self._loader_val = mk_loader(
            mixed_val, args.bs, shuffle=False, drop_last=False
        )

        # test domains carry no domain label, so no decorator is applied
        self.dict_dset_te = {}
        for name in domains_test:
            # get_dset_by_domain always returns train and validation
            # datasets; only the first one is needed for a test domain
            only_te, *_ = self.get_dset_by_domain(args, name, split=False)
            self.dict_dset_te[name] = only_te

        mixed_te = ConcatDataset(tuple(self.dict_dset_te.values()))
        self._loader_te = mk_loader(
            mixed_te, args.bs, shuffle=False, drop_last=False
        )