Source code for domainlab.models.interface_vae_xyd

"""
Base Class for XYD VAE
"""
import torch
import torch.distributions as dist

from domainlab.utils.utils_class import store_args


[docs] class InterfaceVAEXYD: """ Interface (without constructor and inheritance) for XYD VAE """
[docs] def init(self): self.chain_node_builder.init_business(self.zd_dim, self.zx_dim, self.zy_dim) self.i_c = self.chain_node_builder.i_c self.i_h = self.chain_node_builder.i_h self.i_w = self.chain_node_builder.i_w self._init_components()
def _init_components(self): """ q(z|x) p(zy) q_{classif}(zy) """ self.add_module("encoder", self.chain_node_builder.build_encoder()) self.add_module("decoder", self.chain_node_builder.build_decoder()) self.add_module( "net_p_zy", self.chain_node_builder.construct_cond_prior(self.dim_y, self.zy_dim), )
[docs] def init_p_zx4batch(self, batch_size, device): """ 1. Generate pytorch distribution object. 2. To be called by trainer :param batch_size: :param device: """ # p(zx): isotropic gaussian zx_p_loc = torch.zeros(batch_size, self.zx_dim).to(device) zx_p_scale = torch.ones(batch_size, self.zx_dim).to(device) p_zx = dist.Normal(zx_p_loc, zx_p_scale) return p_zx