Source code for domainlab.utils.test_img

import torch


def mk_img(i_h: int, i_ch: int = 3, batch_size: int = 5) -> torch.Tensor:
    """Create a batch of identical random square images for testing.

    A single ``i_h`` x ``i_h`` plane is sampled from ``torch.rand`` and then
    copied across every channel and every batch entry, so all channels of
    all images in the returned batch hold the same values.

    Args:
        i_h: height and width of the square image.
        i_ch: number of channels (3 gives an RGB-like image).
        batch_size: number of images in the batch.

    Returns:
        Tensor of shape ``(batch_size, i_ch, i_h, i_h)``.
    """
    plane = torch.rand(i_h, i_h)  # uniform samples from [0, 1)
    # ``repeat`` with more repeat factors than tensor dims prepends new
    # leading dimensions, so one call yields (batch_size, i_ch, i_h, i_h).
    return plane.repeat(batch_size, i_ch, 1, 1)
def mk_rand_label_onehot(target_dim: int = 10, batch_size: int = 5) -> torch.Tensor:
    """Create a batch of random one-hot label vectors for testing.

    Args:
        target_dim: number of classes, i.e. the length of each one-hot row.
        batch_size: number of label rows to generate.

    Returns:
        Tensor of shape ``(batch_size, target_dim)`` in the default floating
        dtype, where each row has a single 1.0 at a randomly drawn class
        index and 0.0 everywhere else.
    """
    # Draw one class index per sample from [0, target_dim).
    class_idx = torch.randint(high=target_dim, size=(batch_size,))
    onehot = torch.zeros(batch_size, target_dim)
    # scatter_ needs the index to have the same rank as the target, hence the
    # extra column dimension; it writes 1.0 at (row, class_idx[row]).
    onehot.scatter_(1, class_idx.unsqueeze(1), 1.0)
    return onehot
def mk_rand_xyd(ims, y_dim, d_dim, batch_size, i_ch=3):
    """Create a random (image, class label, domain label) batch for testing.

    Args:
        ims: height and width of the square images.
        y_dim: length of each one-hot vector in the first label batch
            (named ``ys``; presumably class labels).
        d_dim: length of each one-hot vector in the second label batch
            (named ``ds``; presumably domain labels).
        batch_size: number of samples in each returned tensor.
        i_ch: number of image channels; defaults to 3, which matches the
            default of ``mk_img``.

    Returns:
        Tuple ``(imgs, ys, ds)`` holding the image batch produced by
        ``mk_img`` and two independently drawn one-hot label batches
        produced by ``mk_rand_label_onehot``.
    """
    imgs = mk_img(i_h=ims, i_ch=i_ch, batch_size=batch_size)
    ys = mk_rand_label_onehot(target_dim=y_dim, batch_size=batch_size)
    ds = mk_rand_label_onehot(target_dim=d_dim, batch_size=batch_size)
    return imgs, ys, ds