Source code for domainlab.tasks.task_pathlist

"""
The class NodeTaskPathList lets the user describe a dataset through a text
file in which each line consists of a pair: the first slot contains the path
of an image (either absolute, or relative if the user knows from where this
package is executed) and the second slot contains the class label as a
numerical string.
"""
import os

import torch.multiprocessing

from domainlab.dsets.dset_img_path_list import DsetImPathList
from domainlab.dsets.utils_data import mk_fun_label2onehot
from domainlab.tasks.b_task_classif import NodeTaskDictClassif

# Select the "file_system" sharing strategy for torch.multiprocessing as a
# workaround for the "too many opened files" error, see
# https://github.com/pytorch/pytorch/issues/11201
torch.multiprocessing.set_sharing_strategy("file_system")


class NodeTaskPathListDummy(NodeTaskDictClassif):
    """
    Marker base class for path-list tasks.

    It exists so that other code can recognise such a task via
    ``isinstance``; it carries no dataset logic of its own.
    """

    def get_dset_by_domain(self, args, na_domain, split=False):
        """
        Placeholder that always raises; subclasses supply the real lookup.

        :param args: unused here
        :param na_domain: unused here
        :param split: unused here
        :raises NotImplementedError: always
        """
        raise NotImplementedError()
def mk_node_task_path_list(
    isize,
    img_trans_te,
    list_str_y,
    img_trans_tr,
    dict_class_label_ind2name,
    dict_domain2imgroot,
    dict_d2filepath_list_img_tr,
    dict_d2filepath_list_img_val,
    dict_d2filepath_list_img_te,
    succ=None,
):
    """
    Build and return a path-list task node configured from the arguments.

    The node class is declared inside this function so that its methods can
    reach the arguments through the enclosing scope.

    :param isize: stored on the task as ``isize``
        (presumably the image size -- confirm against the base class)
    :param img_trans_te: image transformation used for every domain that is
        not among the task's training domains
    :param list_str_y: stored on the task as ``list_str_y``; its length is
        handed to ``mk_fun_label2onehot`` to build the target transformation
    :param img_trans_tr: image transformation used for training domains when
        the task holds no per-domain transformation
    :param dict_class_label_ind2name: stored on the task unchanged
    :param dict_domain2imgroot: maps a domain name to the root directory of
        its images; the keys become the task's list of domains
    :param dict_d2filepath_list_img_tr: maps a domain name to the path of the
        file listing its training images
    :param dict_d2filepath_list_img_val: maps a domain name to the path of the
        file listing its validation images
    :param dict_d2filepath_list_img_te: maps a domain name to the path of the
        file listing the images used when no split is requested
    :param succ: forwarded to the constructor of the node
    :return: an instance of the locally defined ``NodeTaskPathList``
    """

    class NodeTaskPathList(NodeTaskPathListDummy):
        """
        Task whose datasets are described by user supplied list files.

        Every line of such a file pairs the path of an image (absolute, or
        relative to the directory from which this package is executed) with
        its class label written as a numerical string, e.g.:

            /path/2/file/art_painting/dog/pic_376.jpg 1

        NOTE(review): the delimiter between the two slots is decided by
        ``DsetImPathList``; the example above uses whitespace -- confirm.
        """

        def __init__(self, succ=None):
            super().__init__(succ)
            self.conf()

        def conf(self):
            """
            Copy the arguments of the enclosing factory onto the task and
            register the domains.
            """
            self.list_str_y = list_str_y
            self.isize = isize
            self.dict_class_label_ind2name = dict_class_label_ind2name
            self.dict_domain2imgroot = dict_domain2imgroot
            self._dict_domain2filepath_list_im_tr = dict_d2filepath_list_img_tr
            self._dict_domain2filepath_list_im_val = dict_d2filepath_list_img_val
            self._dict_domain2filepath_list_im_te = dict_d2filepath_list_img_te
            # the domains of the task are the keys of the image-root mapping
            names_domain = list(self.dict_domain2imgroot.keys())
            self.set_list_domains(names_domain)

        def _get_complete_domain(self, na_domain, dict_domain2pathfilepath):
            """
            Create the dataset of one domain from one list file.

            :param na_domain: name of the domain
            :param dict_domain2pathfilepath: maps a domain name to the path
                of the list file to read
            :return: the ``DsetImPathList`` built for that domain
            """
            # pick the image transformation: non-training domains get the
            # test transformation, training domains prefer a per-domain one
            if na_domain not in self.list_domain_tr:
                img_transform = img_trans_te
            elif self._dict_domain_img_trans:
                img_transform = self._dict_domain_img_trans[na_domain]
            else:
                img_transform = img_trans_tr

            dir_root_raw = self.dict_domain2imgroot[na_domain]
            # "~" is allowed in both the list-file path and the image root
            file_list = os.path.expanduser(dict_domain2pathfilepath[na_domain])
            dir_root = os.path.expanduser(dir_root_raw)

            return DsetImPathList(
                dir_root,
                file_list,
                trans_img=img_transform,
                trans_target=mk_fun_label2onehot(len(self.list_str_y)),
            )

        def get_dset_by_domain(self, args, na_domain, split=True):
            """
            Return a pair of datasets for the given domain.

            :param args: not used by this implementation
            :param na_domain: name of the domain
            :param split: if true, return the datasets read from the training
                and the validation list files; otherwise the dataset read
                from the test list file is returned twice
            :return: tuple of two datasets
            """
            if split:
                # training and validation lists come from user configuration
                dset_tr = self._get_complete_domain(
                    na_domain, self._dict_domain2filepath_list_im_tr
                )
                dset_val = self._get_complete_domain(
                    na_domain, self._dict_domain2filepath_list_im_val
                )
                return dset_tr, dset_val

            # no train/val split for a test domain: the user is required to
            # provide tr, val and te list files, and only te, which contains
            # the whole dataset (train+validation), is used here
            dset_whole = self._get_complete_domain(
                na_domain, self._dict_domain2filepath_list_im_te
            )
            # @FIXME: avoid returning two identical
            return dset_whole, dset_whole

    return NodeTaskPathList(succ)