Source code for domainlab.tasks.task_folder

"""
When class names and numbers does not match across different domains
"""
from torchvision import transforms

from domainlab.dsets.dset_subfolder import DsetSubFolder
from domainlab.dsets.utils_data import (
    DsetInMemDecorator,
    fun_img_path_loader_default,
    mk_fun_label2onehot,
)
from domainlab.tasks.b_task_classif import NodeTaskDictClassif
from domainlab.tasks.utils_task import DsetClassVecDecoratorImgPath
from domainlab.utils.logger import Logger


[docs] class NodeTaskFolder(NodeTaskDictClassif): """ create dataset by loading files from an organized folder then each domain correspond to one dataset """ @property def dict_domain2imgroot(self): """ {"domain name":"xx/yy/zz"} """ return self._dict_domains2imgroot @dict_domain2imgroot.setter def dict_domain2imgroot(self, dict_root): """ {"domain name":"xx/yy/zz"} """ if not isinstance(dict_root, dict): raise RuntimeError("input is not diciontary") self._dict_domains2imgroot = dict_root @property def extensions(self): """ return allowed extensions """ return self.dict_att["img_extensions"] @extensions.setter def extensions(self, str_format): self.dict_att["img_extensions"] = str_format
[docs] def get_dset_by_domain(self, args, na_domain, split=False): if float(args.split): raise RuntimeError( "this task does not support spliting training domain yet" ) if self._dict_domain_img_trans: trans = self._dict_domain_img_trans[na_domain] if na_domain not in self.list_domain_tr: trans = self.img_trans_te else: trans = transforms.ToTensor() dset = DsetSubFolder( root=self.dict_domain2imgroot[na_domain], list_class_dir=self.list_str_y, loader=fun_img_path_loader_default, extensions=self.extensions, transform=trans, target_transform=mk_fun_label2onehot(len(self.list_str_y)), ) return dset, dset # @FIXME: validation by default set to be training set
[docs] class NodeTaskFolderClassNaMismatch(NodeTaskFolder): """ when the folder names of the same class from different domains have different names """
[docs] def get_dset_by_domain(self, args, na_domain, split=False): if float(args.split): raise RuntimeError( "this task does not support spliting training domain yet" ) logger = Logger.get_logger() logger.info(f"reading domain: {na_domain}") domain_class_dirs = self._dict_domain_folder_name2class[na_domain].keys() if self._dict_domain_img_trans: trans = self._dict_domain_img_trans[na_domain] if na_domain not in self.list_domain_tr: trans = self.img_trans_te else: trans = transforms.ToTensor() ext = None if self.extensions is None else self.extensions[na_domain] dset = DsetSubFolder( root=self.dict_domain2imgroot[na_domain], list_class_dir=list(domain_class_dirs), loader=fun_img_path_loader_default, extensions=ext, transform=trans, target_transform=mk_fun_label2onehot(len(self.list_str_y)), ) # dset.path2imgs dict_folder_name2class_global = self._dict_domain_folder_name2class[na_domain] dset = DsetClassVecDecoratorImgPath( dset, dict_folder_name2class_global, self.list_str_y ) # Always use the DsetInMemDecorator at the last step # since it does not have other needed attributes in bewteen if args.dmem: dset = DsetInMemDecorator(dset, na_domain) return dset, dset # @FIXME: validation by default set to be training set