# Source code for domainlab.models.model_vae_xyd_classif
"""
Base Class for XYD VAE Classify
"""
from domainlab.models.a_model_classif import AModelClassif
from domainlab.models.interface_vae_xyd import InterfaceVAEXYD
from domainlab.utils.utils_class import store_args


class VAEXYDClassif(AModelClassif, InterfaceVAEXYD):
    """
    Base Class for DIVA and HDUVA
    """
    @store_args
    def __init__(self, chain_node_builder, zd_dim, zy_dim, zx_dim, **kwargs):
        """
        Construct the classifier and its sub-networks.

        :param chain_node_builder: constructed object; ``_init_components``
            uses it to build the ``net_classif_y`` classifier head
        :param zd_dim: presumably the size of the domain latent ``zd`` (TODO confirm)
        :param zy_dim: passed (via ``self.zy_dim``) to the classifier builder in
            ``_init_components``; presumably the size of the class latent ``zy``
        :param zx_dim: presumably the size of the residual latent ``zx`` (TODO confirm)
        :param kwargs: must contain ``list_str_y``, which is forwarded to the
            parent constructor; other entries are not read here directly
            (``@store_args`` is expected to store them on the instance — confirm)
        :raises TypeError: if ``list_str_y`` is not supplied in ``kwargs``
        """
        # Fail fast with a clear message; previously a missing key surfaced as
        # an opaque UnboundLocalError on the super().__init__ call below.
        try:
            list_str_y = kwargs["list_str_y"]
        except KeyError:
            raise TypeError(
                f"{type(self).__name__}.__init__() missing required keyword "
                "argument: 'list_str_y'"
            ) from None
        # The classifier head does not exist yet, so the parent receives None
        # and the real network is attached after self.init() below.
        super().__init__(net_classifier=None, list_str_y=list_str_y)
        # NOTE(review): init() is defined outside this view; it is assumed to
        # call _init_components(), which registers net_classif_y — confirm.
        self.init()
        self._net_classifier = self.net_classif_y
    @property
    def multiplier4task_loss(self):
        """
        Weight applied to the task (classification) loss.

        The multiplier defaults to 1.0 for non-VAE models; for this VAE-family
        classifier it is the instance attribute ``gamma_y`` (presumably
        supplied as a constructor keyword argument — confirm).
        """
        task_weight = self.gamma_y
        return task_weight
    def _init_components(self):
        """
        Build sub-networks: the parent's components first, then the
        ``net_classif_y`` classifier head.

        The head is created by ``self.chain_node_builder.construct_classifier``
        from ``self.zy_dim`` and ``self.dim_y`` and registered as a submodule
        under the name ``"net_classif_y"``.
        """
        super()._init_components()
        classifier_head = self.chain_node_builder.construct_classifier(
            self.zy_dim, self.dim_y
        )
        self.add_module("net_classif_y", classifier_head)
 
        # The property setter is meant only for other objects; internally, use _net_classifier directly.