Source code for domainlab.utils.flows_gen_img_model

import os

import torch

from domainlab.compos.vae.c_vae_recon import ReconVAEXYD
from domainlab.utils.utils_img_sav import mk_fun_sav_img


[docs] class FlowGenImgs: def __init__(self, model, device): model = model.to(device) self.obj_recon = ReconVAEXYD(model) self.sav_fun = None def _save_list(self, recon_list_d, device, img, name): """ compare counterfactual """ recon_stack_d = torch.cat(recon_list_d) recon_stack_d = recon_stack_d.to(device) comparison = torch.cat([img, recon_stack_d]) self.sav_fun(comparison, name, title=None) # row is actually column here def _save_pair(self, x_recon_img, device, img, name): """ compare recon and original """ x_recon_img = x_recon_img.to(device) comparison = torch.cat([img, x_recon_img]) self.sav_fun(comparison, name)
[docs] def gen_img_loader(self, loader, device, path, domain): """ gen image for the first batch of input loader """ for _, (x_batch, y_batch, *d_batch) in enumerate(loader): x_batch, y_batch = x_batch.to(device), y_batch.to(device) if d_batch and isinstance(d_batch[0], torch.Tensor): d_batch = d_batch[0].to(device) else: d_batch = None self.gen_img_xyd(x_batch, y_batch, d_batch, device, path, folder_na=domain) break
[docs] def gen_img_xyd(self, img, vec_y, vec_d, device, path, folder_na): img = img.to(device) vec_y = vec_y.to(device) if vec_d is not None: vec_d = vec_d.to(device) nrow = img.shape[0] # nrow is actually ncol in pytorch save_image!! self.sav_fun = mk_fun_sav_img(path=path, nrow=nrow, folder_na=folder_na) self._flow_vanilla(img, vec_y, vec_d, device) self._flow_cf_y(img, vec_y, vec_d, device) if vec_d is not None: self._flow_cf_d(img, vec_y, vec_d, device)
def _flow_vanilla(self, img, vec_y, vec_d, device, num_sample=10): x_recon_img, str_type = self.obj_recon.recon(img) self._save_pair(x_recon_img, device, img, str_type + ".png") x_recon_img, str_type = self.obj_recon.recon(img, vec_y) self._save_pair(x_recon_img, device, img, str_type + ".png") if vec_d is not None: x_recon_img, str_type = self.obj_recon.recon(img, None, vec_d) self._save_pair(x_recon_img, device, img, str_type + ".png") for i in range(num_sample): x_recon_img, str_type = self.obj_recon.recon(img, vec_y, vec_d, True, True) x_recon_img = x_recon_img.to(device) comparison = torch.cat([img, x_recon_img]) self.sav_fun(comparison, str_type + str(i) + ".png") def _flow_cf_y(self, img, vec_y, vec_d, device): """ scan possible values of vec_y """ recon_list, str_type = self.obj_recon.recon_cf( img, "y", vec_y.shape[1], device, zx2fill=None ) self._save_list( recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png" ) recon_list, str_type = self.obj_recon.recon_cf( img, "y", vec_y.shape[1], device, zx2fill=0 ) self._save_list( recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png" ) if vec_d is not None: recon_list, str_type = self.obj_recon.recon_cf( img, "y", vec_y.shape[1], device, vec_d=vec_d, zx2fill=0 ) self._save_list( recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png" ) def _flow_cf_d(self, img, vec_y, vec_d, device): """ scan possible values of vec_y """ recon_list, str_type = self.obj_recon.recon_cf( img, "d", vec_d.shape[1], device, zx2fill=None ) self._save_list( recon_list, device, img, "_".join(["recon_cf_d", str_type]) + ".png" ) recon_list, str_type = self.obj_recon.recon_cf( img, "d", vec_d.shape[1], device, zx2fill=0 ) self._save_list( recon_list, device, img, "_".join(["recon_cf_d", str_type]) + ".png" )
[docs] def fun_gen(model, device, node, args, subfolder_na, output_folder_na="gen"): flow = FlowGenImgs(model, device) path = os.path.join( args.out, output_folder_na, node.task_name, args.model, subfolder_na ) flow.gen_img_loader(node.loader_te, device, path=path, domain="_".join(args.te_d)) flow.gen_img_loader( node.loader_tr, device, path=path, domain="_".join(node.list_domain_tr) )