Source code for domainlab.tasks.utils_task_dset

"""
task specific dataset operation
"""
import random

from torch.utils.data import Dataset


class DsetIndDecorator4XYD(Dataset):
    """
    For a dataset yielding (x, y, d, ...), decorate each item with its index.

    Items of the wrapped dataset must unpack to at least three components;
    any components after the third are dropped by this wrapper.
    """

    def __init__(self, dset):
        """
        :param dset: dataset whose items unpack to at least x, y, d
        :raises RuntimeError: if the first item of ``dset`` has fewer than
            three components
        """
        # Probe only the first sample to validate the item layout; an empty
        # dataset propagates whatever error ``dset[0]`` raises.
        tuple_m = dset[0]
        if len(tuple_m) < 3:
            # Build one message string: passing the length as a second
            # exception argument would render the error as a tuple repr.
            raise RuntimeError(
                "dataset to be wrapped should output at least x, y, and d; "
                f"got length {len(tuple_m)}"
            )
        self.dset = dset

    def __getitem__(self, index):
        """
        :param index: position of the sample in the wrapped dataset
        :return: tuple ``(x, y, d, index)``
        """
        tensor_x, vec_y, vec_d, *_ = self.dset[index]
        return tensor_x, vec_y, vec_d, index

    def __len__(self):
        """
        :return: number of samples in the wrapped dataset
        """
        return len(self.dset)
class DsetZip(Dataset):
    """
    Pair two datasets so that one lookup returns a sample from each.

    The first dataset is read at the requested index, while the second one is
    read at a randomly shifted index, so the same two samples are not always
    matched together.
    """

    def __init__(self, dset1, dset2, name=None):
        """
        :param dset1: dataset whose items unpack to x1, y1, d1, *others
        :param dset2: dataset whose items unpack to x2, y2, d2, *others
        :param name: name of dataset
        """
        self.name = name
        self.dset1 = dset1
        self.dset2 = dset2
        # Length of the second dataset is read once, at construction time.
        self.len2 = len(dset2)

    def __getitem__(self, idx):
        """
        :param idx: index into the first dataset
        :return: ``(x1, y1, d1, others1, x2, y2, d2, others2)`` where each
            ``others`` is a list of the components after the third
        """
        # Random offset, wrapped around so it stays inside the second dataset.
        shifted = (idx + random.randrange(self.len2)) % self.len2
        x_a, y_a, d_a, *rest_a = self.dset1[idx]
        x_b, y_b, d_b, *rest_b = self.dset2[shifted]
        return x_a, y_a, d_a, rest_a, x_b, y_b, d_b, rest_b

    def __len__(self):
        """
        :return: length of the shorter of the two datasets
        """
        return min(len(self.dset1), self.len2)