"""Source code for domainlab.dsets.dset_img_path_list."""
import os
import torch.utils.data as data
from domainlab.dsets.utils_data import fun_img_path_loader_default
from domainlab.utils.utils_class import store_args
class DsetImPathList(data.Dataset):
    """
    Dataset for one domain, described by a text file of image paths and labels.

    Each non-blank line of ``path2filelist`` has the form
    ``<image path> <integer label>``. Image paths are resolved relative to
    ``root_img`` when a sample is loaded.
    """

    @store_args
    def __init__(self, root_img, path2filelist, trans_img=None,
                 trans_target=None, label_offset=1):
        """
        One file provides image paths and labels, which together form a domain.

        :param root_img: directory the image paths in the file list are
            relative to
        :param path2filelist: path to the text file listing
            ``<image path> <label>`` pairs, one per line
        :param trans_img: optional callable applied to each loaded image
        :param trans_target: optional callable applied to each label after
            ``label_offset`` has been subtracted
        :param label_offset: value subtracted from every label read from the
            file; the default 1 reproduces the previous hard-coded behavior
            (1-based labels in the file -> 0-based targets)
        """
        # The other attributes (root_img, path2filelist, trans_*) are read via
        # self.* and are expected to be set by @store_args. Assign this one
        # explicitly so __getitem__ does not depend on how the decorator
        # handles newly added defaulted parameters.
        self.label_offset = label_offset
        self.list_tuple_img_label = []
        self.get_list_tuple_img_label()

    def get_list_tuple_img_label(self):
        """
        Parse ``self.path2filelist`` into ``(path_img, label)`` tuples.

        Appends one tuple per non-blank line to ``self.list_tuple_img_label``.
        Labels are converted with ``int`` exactly as written; they are not
        remapped, so they are not guaranteed to be contiguous.

        :raises ValueError: if a non-blank line is not an image path followed
            by an integer label
        """
        with open(self.path2filelist, "r") as f_h:
            for line_no, str_line in enumerate(f_h, start=1):
                str_line = str_line.strip()
                if not str_line:
                    # Tolerate blank lines, e.g. an empty trailing line.
                    continue
                # Split the label off the right end only, so that image paths
                # containing spaces are kept intact.
                fields = str_line.rsplit(maxsplit=1)
                if len(fields) != 2:
                    raise ValueError(
                        f"{self.path2filelist}:{line_no}: expected "
                        f"'<image path> <integer label>', got {str_line!r}"
                    )
                path_img, label_img = fields
                try:
                    label = int(label_img)
                except ValueError as err:
                    raise ValueError(
                        f"{self.path2filelist}:{line_no}: label "
                        f"{label_img!r} is not an integer"
                    ) from err
                # @FIXME: string to int, not necessarily continuous
                self.list_tuple_img_label.append((path_img, label))

    def __getitem__(self, index):
        """
        Load one sample.

        :param index: position in the parsed file list
        :return: ``(img, target, path_img)``, where:

            - ``img`` is the image loaded by ``fun_img_path_loader_default``
              from ``root_img/path_img``, with ``trans_img`` applied if given
            - ``target`` is the file-list label minus ``label_offset``, with
              ``trans_target`` applied if given
            - ``path_img`` is the path as written in the file list
        """
        path_img, target = self.list_tuple_img_label[index]
        # @FIXME: a fixed offset assumes labels form a contiguous range
        # starting at label_offset.
        target = target - self.label_offset
        img = fun_img_path_loader_default(os.path.join(self.root_img, path_img))
        if self.trans_img is not None:
            img = self.trans_img(img)
        if self.trans_target is not None:
            target = self.trans_target(target)
        return img, target, path_img

    def __len__(self):
        """Return the number of entries parsed from the file list."""
        return len(self.list_tuple_img_label)