# Source code for domainlab.compos.vae.c_vae_recon

"""
Adaptor is vital for data generation so this class can be decoupled from model class.
The model class can be refactored, we do want to use the trained old-version model, which we
only need to change adaptor class.
"""
import torch

from domainlab.compos.vae.c_vae_adaptor_model_recon import AdaptorReconVAEXYD


class ReconVAEXYD:
    """
    Reconstruction helper that talks to a trained VAE only through an adaptor.

    The adaptor is vital for data generation, so this class stays decoupled
    from the model class: when the model class is refactored, a trained
    old-version model can still be used by swapping only the adaptor class.
    """

    def __init__(self, model, na_adaptor=AdaptorReconVAEXYD):
        """
        :param model: trained VAE model used for reconstruction
        :param na_adaptor: adaptor class (or any callable taking the model as
            its single argument); defaults to AdaptorReconVAEXYD
        """
        self.model = model
        # build the adaptor around the model so all latent/prior queries go
        # through a stable interface
        self.adaptor = na_adaptor(self.model)
[docs] def recon( self, x, vec_y=None, vec_d=None, sample_p_zy=False, sample_p_zd=False, scalar_zx2fill=None, ): """ common function """ str_type = "" with torch.no_grad(): q_zd, q_zx, q_zy = self.adaptor.cal_latent(x) if vec_d is None: recon_zd = q_zd.mean str_type = "zd_q" else: p_zd = self.adaptor.cal_prior_zd(vec_d) recon_zd = p_zd.mean str_type = "zd_p" if sample_p_zd: recon_zd = p_zd.rsample() str_type = "_".join([str_type, "sample__"]) zx_loc_q = q_zx.mean if scalar_zx2fill is not None: recon_zx = torch.zeros_like(zx_loc_q) recon_zx = recon_zx.fill_(scalar_zx2fill) str_type = "_".join( [str_type, "__fill_zx_", str(scalar_zx2fill), "___"] ) else: recon_zx = zx_loc_q str_type = "_".join([str_type, "__zx_q__"]) if vec_y is not None: p_zy = self.adaptor.cal_prior_zy(vec_y) recon_zy = p_zy.mean str_type = "_".join([str_type, "__zy_p__"]) if sample_p_zy: recon_zy = p_zy.rsample() str_type = "_".join([str_type, "sample__"]) else: recon_zy = q_zy.mean str_type = "_".join([str_type, "zy_q__"]) img_recon = self.adaptor.recon_ydx(recon_zy, recon_zd, recon_zx, x) return img_recon, str_type
[docs] def recon_cf( self, x, na_cf, dim_cf, device, vec_y=None, vec_d=None, zx2fill=None, sample_p_zy=False, sample_p_zd=False, ): """ Countefactual reconstruction: :param na_cf: name of counterfactual, 'y' or 'd' :param dim_cf: dimension of counterfactual factor """ list_recon_cf = [] batch_size = x.shape[0] str_type = None # Counterfactual image generation for i in range(dim_cf): label_cf = torch.zeros(batch_size, dim_cf).to(device) label_cf[:, i] = 1 if na_cf == "y": img_recon_cf, str_type = self.recon( x, label_cf, vec_d, sample_p_zy, sample_p_zd, zx2fill ) elif na_cf == "d": img_recon_cf, str_type = self.recon( x, vec_y, label_cf, sample_p_zy, sample_p_zd, zx2fill ) else: raise RuntimeError( "counterfactual image generation can only be 'y' or 'd'" ) list_recon_cf.append(img_recon_cf) return list_recon_cf, str_type