Source code for domainlab.dsets.utils_data
"""
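Utilities for datasets: default image loading, one-hot label encoding,
plotting of dataset samples, and in-memory dataset caching.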
"""
import datetime
import torch
import torch.utils.data as data_utils
from PIL import Image
from torch.utils.data import Dataset
from torchvision.utils import save_image
from domainlab.utils.logger import Logger
[docs]
def fun_img_path_loader_default(path):
    """
    Load the image file at ``path`` and convert it to 3-channel RGB.

    Converting to RGB gives grayscale, palette and RGBA images the same
    channel layout, see
    https://discuss.pytorch.org/t/handling-rgba-images/88428/4

    The file is opened in a ``with`` block so its handle is released right
    away. Pillow otherwise keeps the handle open, e.g. for multi-frame
    images. ``Image.convert`` returns a new image whose pixel data is
    already loaded, so the result stays usable after the file is closed.

    :param path: path of the image file to read
    :return: a new ``PIL.Image.Image`` in mode ``"RGB"``
    """
    with Image.open(path) as img:
        return img.convert("RGB")
[docs]
def mk_fun_label2onehot(dim):
    """
    Function generator: build a function that maps class indices to one-hot
    vectors of length ``dim``.

    The ``dim x dim`` identity matrix is created once here, not on every
    call of the returned function. Each conversion therefore costs O(dim)
    instead of O(dim**2).

    :param dim: number of classes, i.e. the length of each one-hot vector
    :return: ``fun_label2onehot(label)``, which returns row(s) ``label`` of a
        float identity matrix of size ``dim``
    """
    m_eye = torch.eye(dim)

    def fun_label2onehot(label):
        """
        :param label: row index into the identity matrix (int, integer
            tensor, or anything else tensor indexing accepts)
        :return: the selected identity row(s) as a new tensor
        """
        # Integer indexing returns a view into the shared matrix; clone it so
        # in-place edits by the caller cannot corrupt later one-hot vectors.
        return m_eye[label].clone()

    return fun_label2onehot
[docs]
def plot_ds(dset, f_name, batchsize=32):
    """
    Save the first (unshuffled) batch of ``dset`` as an image grid.

    :param dset: dataset whose collated batches unpack as ``(img, _, *rest)``,
        with the images as the first element
    :param f_name: file the image grid is written to
    :param batchsize: number of samples in the single batch that is drawn
    """
    batches = iter(data_utils.DataLoader(dset, batch_size=batchsize, shuffle=False))
    missing = object()
    first_batch = next(batches, missing)
    if first_batch is missing:
        # empty dataset: there is nothing to draw
        return
    imgs, _, *_ = first_batch
    # at most 8 images per grid row
    save_image(imgs.cpu(), f_name, nrow=min(imgs.size(0), 8))
[docs]
def plot_ds_list(ds_list, f_name, batchsize=8, shuffle=False):
    """
    Plot the first batch of each dataset in ``ds_list``, one dataset per row.

    The grid has ``batchsize`` cells per row. If a dataset yields fewer than
    ``batchsize`` images, its batch is padded with zero images so that the
    next dataset still starts on its own row. Without padding the grid
    would wrap and rows would mix datasets. Datasets that yield no batch
    at all are skipped.

    :param ds_list: iterable of datasets whose collated batches unpack as
        ``(img, _, *rest)``; all images must share the same per-image shape
        so they can be concatenated
    :param f_name: file the image grid is written to
    :param batchsize: number of images per dataset, and per grid row
    :param shuffle: whether each dataset is shuffled before its first batch
        is taken
    """
    list_imgs = []
    for dset in ds_list:
        loader = data_utils.DataLoader(dset, batch_size=batchsize, shuffle=shuffle)
        for img, _, *_ in loader:
            n_missing = batchsize - img.size(0)
            if n_missing > 0:
                # pad a short batch so this dataset fills exactly one row
                padding = img.new_zeros((n_missing, *img.shape[1:]))
                img = torch.cat([img, padding])
            list_imgs.append(img)
            break  # only the first batch of each dataset
    comparison = torch.cat(list_imgs)
    save_image(comparison.cpu(), f_name, nrow=batchsize)
[docs]
class DsetInMemDecorator(Dataset):
    """
    Wrap a dataset and cache all of its items in a Python list.

    Every item is fetched once, in the constructor. Afterwards
    ``__getitem__`` serves items from the cache without indexing the
    wrapped dataset again.
    """

    def __init__(self, dset, name=None):
        """
        :param dset: indexable dataset with ``len``; its items are
            presumably ``x, y, *d`` tuples
        :param name: optional dataset name, used only in a log message
        """
        self.dset = dset
        logger = Logger.get_logger()
        if name is not None:
            logger.info(f"loading dset {name}")
        started_at = datetime.datetime.now()
        self.item_list = [self.dset[pos] for pos in range(len(self.dset))]
        elapsed = datetime.datetime.now() - started_at
        logger.info(f"loading dataset to memory taken: {elapsed}")

    def __getitem__(self, idx):
        """
        :param idx: position of the cached item
        :return: the item fetched from the wrapped dataset at construction
        """
        return self.item_list[idx]

    def __len__(self):
        """Length of the wrapped dataset."""
        return len(self.dset)