Source code for domainlab.tasks.utils_task

"""
Task wraps around datasets, this file provide utilities
"""
import itertools
import os
from pathlib import Path

import numpy
import torch
import torchvision
from torch.utils.data import Dataset

from domainlab.utils.utils_class import store_args


class ImSize:
    """Image geometry: number of channels, height and width."""

    @store_args
    def __init__(self, i_c, i_h, i_w):
        """
        Record channel, height and width.

        The body is intentionally empty: the properties below read
        ``self.i_c``, ``self.i_h`` and ``self.i_w``, which are presumably set
        by the ``store_args`` decorator (TODO confirm in
        ``domainlab.utils.utils_class``).

        :param i_c: image channel count
        :param i_h: image height
        :param i_w: image width
        """

    @property
    def c(self):
        """Channel count of the image (``i_c``)."""
        return self.i_c

    @property
    def h(self):
        """Height of the image (``i_h``)."""
        return self.i_h

    @property
    def w(self):
        """Width of the image (``i_w``)."""
        return self.i_w
def mk_onehot(dim, ind):
    """
    Build a one-hot float vector by selecting row(s) of an identity matrix.

    :param dim: length of the one-hot vector
    :param ind: position of the hot entry (anything that can index the
        rows of ``torch.eye(dim)``)
    :return: ``torch.eye(dim)[ind]``
    """
    return torch.eye(dim)[ind]
def mk_loader(dset, bsize, drop_last=True, shuffle=True, num_workers=0):
    """
    Wrap a dataset into a ``torch.utils.data.DataLoader``.

    :param dset: dataset to load from; must support ``len()``
    :param bsize: requested batch size. It is clamped to ``len(dset)`` so
        that a dataset smaller than one batch still yields one (full) batch
        even when ``drop_last`` is True.
    :param drop_last: drop the last incomplete batch
    :param shuffle: reshuffle the samples every epoch
    :param num_workers: number of loader worker processes
        (0 loads in the main process)
    :return: the configured DataLoader
    :raises ValueError: if ``dset`` is empty
    """
    num_samples = len(dset)
    if num_samples == 0:
        # torch would also raise ValueError here, but with an opaque message
        # about num_samples / batch_size; fail with a clear one instead
        raise ValueError("mk_loader: cannot build a loader from an empty dataset")
    loader = torch.utils.data.DataLoader(
        dataset=dset,
        batch_size=min(bsize, num_samples),
        shuffle=shuffle,
        # @FIXME: shuffle must be true so the last incomplete batch get used in another epoch?
        num_workers=num_workers,
        # @FIXME: num_workers=int(0) can be slow?
        drop_last=drop_last,
    )
    return loader
class DsetDomainVecDecorator(Dataset):
    """
    Dataset wrapper that attaches a fixed domain vector to every item.

    An item ``(x, y, *extra)`` of the wrapped dataset is returned as
    ``(x, y, vec_domain, *extra)``.
    """

    def __init__(self, dset, vec_domain, na_domain):
        """
        :param dset: wrapped dataset whose items start with ``x, y``
        :param vec_domain: vector representation of the domain, inserted
            unchanged into every item
        :param na_domain: string description of the domain
        """
        self.dset = dset
        self.vec_domain = vec_domain
        self.na_domain = na_domain

    @property
    def targets(self):
        """Targets of the wrapped dataset, forwarded unchanged
        (convenient for counting samples per class)."""
        return self.dset.targets

    def __getitem__(self, idx):
        """
        :param idx: index of the item in the wrapped dataset
        :return: ``(x, y, vec_domain, *extra)``
        """
        x_val, y_val, *extra = self.dset.__getitem__(idx)
        # a single tuple covers both "no extra" and "with extra" items
        return (x_val, y_val, self.vec_domain, *extra)

    def __len__(self):
        """Length of the wrapped dataset."""
        return self.dset.__len__()
class DsetDomainVecDecoratorImgPath(DsetDomainVecDecorator):
    """
    Like ``DsetDomainVecDecorator``, but for a wrapped dataset whose items
    are exactly ``(x, y, path)``; items are returned as
    ``(x, y, vec_domain, path)``.

    Currently not in use; it is mainly meant for printing test-domain
    predictions together with the image path.
    """

    def __getitem__(self, idx):
        """
        :param idx: index of the item in the wrapped dataset
        :return: ``(x, y, vec_domain, path)``; raises ``ValueError`` if the
            wrapped item does not have exactly three elements
        """
        sample = self.dset.__getitem__(idx)
        img, label, img_path = sample
        return (img, label, self.vec_domain, img_path)
class DsetClassVecDecorator(Dataset):
    """
    Decorate a pytorch dataset with new (global) class labels.

    The wrapped dataset labels items with domain-local class indices; this
    wrapper replaces each label by a one-hot vector over the global class
    list ``list_str_y``.
    """

    def __init__(self, dset, dict_folder_name2class_global, list_str_y):
        """
        :param dset: wrapped dataset returning ``(x, y, *other_vars)``; it
            must expose ``class_to_idx`` (local class folder name -> local
            index) and ``list_class_dir`` (class folders in use), and
            ``targets`` if the ``targets`` property is used
        :param dict_folder_name2class_global: dictionary that maps a class
            folder name of this domain to the global class name
        :param list_str_y: list of global class names; the position of a name
            in this list is the hot position of its new one-hot vector
        :raises AssertionError: if no entry of ``dset.class_to_idx`` is in
            ``dset.list_class_dir``
        """
        self.dset = dset
        # keep only the class folders that are actually used in this domain
        self.class2idx = {
            k: v
            for k, v in self.dset.class_to_idx.items()
            if k in self.dset.list_class_dir
        }
        assert self.class2idx, "none of dset.class_to_idx is in dset.list_class_dir"
        self.dict_folder_name2class_global = dict_folder_name2class_global
        self.list_str_y = list_str_y
        # invert key:value to value:key for the backward map
        # (local index -> local class folder name)
        self.dict_old_idx2old_class = {v: k for k, v in self.class2idx.items()}
        # precompute the new one-hot vector of every local class once
        self.dict_class_na_local2vec_new = {
            k: self.fun_class_local_na2vec_new(k) for k in self.class2idx
        }

    @property
    def targets(self):
        """Targets of the wrapped dataset, forwarded unchanged
        (convenient for counting samples per class)."""
        return self.dset.targets

    def fun_class_local_na2vec_new(self, k):
        """
        Map a local class name within one domain to the one-hot vector of
        the new (global) representation.

        :param k: local class folder name
        :return: one-hot vector of length ``len(list_str_y)``
        :raises KeyError: if ``k`` is not in ``dict_folder_name2class_global``
        :raises ValueError: if the mapped global name is not in ``list_str_y``
        """
        ind = self.list_str_y.index(self.dict_folder_name2class_global[k])
        return mk_onehot(len(self.list_str_y), ind)

    def __getitem__(self, idx):
        """
        :param idx: index of the item in the wrapped dataset
        :return: ``(x, y_new, *other_vars)`` where ``y_new`` is the one-hot
            vector of the item's class over ``list_str_y``
        """
        tensor, vec_class, *other_vars = self.dset.__getitem__(idx)
        # the local class index is taken as the argmax of the old label
        # (presumably a one-hot tensor — TODO confirm for all datasets)
        ind_old = numpy.argmax(vec_class.numpy())
        class_local = self.dict_old_idx2old_class[ind_old]
        # NOTE: the same precomputed tensor object is returned for every item
        # of a class; consumers must not modify it in place
        vec_class_new = self.dict_class_na_local2vec_new[class_local]
        return tensor, vec_class_new, *other_vars

    def __len__(self):
        """Length of the wrapped dataset."""
        return self.dset.__len__()
class DsetClassVecDecoratorImgPath(DsetClassVecDecorator):
    """
    ``DsetClassVecDecorator`` for wrapped datasets whose items are
    ``(x, y, path_container)``; items are returned as
    ``(x, y_new, path_container[0])``.
    """

    def __getitem__(self, idx):
        """
        :param idx: index of the item in the wrapped dataset
        :return: ``(x, y_new, path_container[0])``; raises ``ValueError`` if
            the decorated item does not have exactly three elements

        NOTE(review): the third element is indexed with ``[0]``. If the
        wrapped dataset returns the path as a plain ``str``, this yields only
        its first character; presumably the path arrives wrapped in a
        one-element container — TODO confirm against the dataset class.
        """
        decorated = super().__getitem__(idx)
        img, label_new, path_container = decorated
        first_path = path_container[0]
        return img, label_new, first_path
class LoaderDomainLabel:
    """
    Callable that attaches a one-hot domain label to a dataset and wraps the
    result into a loader.
    """

    def __init__(self, batch_size, dim_d):
        """
        :param batch_size: batch size passed to ``mk_loader``
        :param dim_d: length of the one-hot domain vector
        """
        self.batch_size = batch_size
        self.dim_d = dim_d

    def __call__(self, dset, d_ind, na_domain):
        """
        Wrap ``dset`` so every item carries the domain label, then build a
        loader for it.

        :param dset: dataset returning ``(x, y, *other_vars)`` items
        :param d_ind: hot position of this domain in the one-hot domain vector
        :param na_domain: string description of the domain
        :return: loader built by ``mk_loader`` with its default
            ``drop_last``/``shuffle``/``num_workers``
        """
        # same one-hot construction as elsewhere in this module
        d_label = mk_onehot(self.dim_d, d_ind)
        dset = DsetDomainVecDecorator(dset, d_label, na_domain)
        loader = mk_loader(dset, self.batch_size)
        return loader
def tensor1hot2ind(tensor_label):
    """
    Convert a batch of one-hot (or score) rows into integer indices.

    :param tensor_label: tensor reduced over dim 1, typically shaped
        ``(batch, n_class)``; may live on any device
    :return: numpy array with the index of the maximum along dim 1
        (the first one on ties)
    """
    label_ind = torch.argmax(tensor_label, dim=1)
    # .cpu() is required before .numpy() for tensors stored on a GPU;
    # it is a no-op for CPU tensors
    return label_ind.cpu().numpy()
# @FIXME: this function couples strongly with the task, # should be a class method of task
def _ind2label_str(ind, list_na, prefix=""):
    """Map a label index to its name in ``list_na``, or to ``prefix`` + index
    when no name list is given."""
    ind = int(ind)
    if list_na is None:
        return prefix + str(ind)
    # @FIXME: the correspondence between label index and name list is
    # assumed here, not checked
    return list_na[ind]


def img_loader2dir(
    loader, folder, test=False, list_domain_na=None, list_class_na=None, batches=5
):
    """
    Save images from a loader to a directory so that one can visually check
    whether the loader is correct.

    Files are named ``class_<class>_domain_<domain>_n_<counter>.png``, where
    ``<domain>`` is ``unknown`` when no domain label is available.

    :param loader: pytorch data loader yielding ``(img, vec_y, *other_vars)``
    :param folder: folder to save images to (created if missing)
    :param test: if True, the loader is assumed to be a test loader (no
        domain label); if False (default) it is assumed to be a train loader
    :param list_domain_na: optional list of domain names, indexed by the
        domain label index
    :param list_class_na: optional list of class names, indexed by the class
        label index
    :param batches: maximal number of batches to save
    """
    Path(os.path.normpath(folder)).mkdir(parents=True, exist_ok=True)
    to_pil = torchvision.transforms.ToPILImage()
    counter = 0
    # islice stops early if the loader has fewer batches and does not need
    # len(loader); max() treats a negative count as zero, like range() did
    for img, vec_y, *other_vars in itertools.islice(loader, max(batches, 0)):
        class_label_ind_batch = tensor1hot2ind(vec_y)
        # get domain label
        # Note 1: test loaders don't return domain labels (see NodeTaskDictClassif.init_business)
        # Note 2: for train loaders the domain label is the 0th element of
        #         other_vars (see DsetDomainVecDecorator)
        domain_label_ind_batch = None
        if not test and other_vars:
            domain_label_ind_batch = tensor1hot2ind(other_vars[0])
        for b_ind in range(img.shape[0]):
            str_class_label = _ind2label_str(
                class_label_ind_batch[b_ind], list_class_na, prefix="class_"
            )
            str_domain_label = "unknown"
            if domain_label_ind_batch is not None:
                str_domain_label = _ind2label_str(
                    domain_label_ind_batch[b_ind], list_domain_na
                )
            f_n = "_".join(
                [
                    "class",
                    str_class_label,
                    "domain",
                    str_domain_label,
                    "n",
                    str(counter),
                ]
            )
            counter += 1
            path = os.path.join(folder, f_n + ".png")
            to_pil(img[b_ind]).save(path)