# Source code for domainlab.models.model_vae_xyd_classif
"""
Base Class for XYD VAE Classify
"""
from domainlab.models.a_model_classif import AModelClassif
from domainlab.models.interface_vae_xyd import InterfaceVAEXYD
from domainlab.utils.utils_class import store_args
# [docs]
class VAEXYDClassif(AModelClassif, InterfaceVAEXYD):
    """
    Base Class for DIVA and HDUVA.

    Combines ``AModelClassif`` and ``InterfaceVAEXYD`` and adds a classifier
    head, ``net_classif_y``. The head is built by the chain node builder with
    input size ``zy_dim`` and output size ``dim_y``. It is then used as this
    model's ``_net_classifier``.
    """

    @store_args
    def __init__(self, chain_node_builder, zd_dim, zy_dim, zx_dim, **kwargs):
        """
        :param chain_node_builder: constructed object; its
            ``construct_classifier(in_dim, out_dim)`` builds the
            ``net_classif_y`` head.
        :param zd_dim: latent size for ``zd``.
        :param zy_dim: latent size for ``zy``; also the input size of the
            ``net_classif_y`` head.
        :param zx_dim: latent size for ``zx``.
            NOTE(review): these arguments are presumably stored on ``self``
            by ``store_args`` (``_init_components`` reads
            ``self.chain_node_builder`` and ``self.zy_dim``). Confirm.
        :param kwargs: must contain ``list_str_y``, which is forwarded to the
            parent constructor. Other keyword arguments are ignored here.
        :raises TypeError: if ``list_str_y`` is not passed.
        """
        # Fail early with a clear message instead of the UnboundLocalError
        # a missing key would otherwise cause.
        if "list_str_y" not in kwargs:
            raise TypeError(
                f"{type(self).__name__}.__init__() missing required keyword "
                "argument: 'list_str_y'"
            )
        list_str_y = kwargs["list_str_y"]
        super().__init__(net_classifier=None, list_str_y=list_str_y)
        # NOTE(review): init() presumably triggers _init_components(), which
        # registers ``net_classif_y`` (read on the next line). Confirm.
        self.init()
        # Use the head built in _init_components as the generic classifier.
        self._net_classifier = self.net_classif_y

    @property
    def multiplier4task_loss(self):
        """
        Return the multiplier applied to the task loss.

        The default is 1.0 except for VAE-family models; this class returns
        ``self.gamma_y``.
        NOTE(review): ``gamma_y`` is not set in this class; it is assumed to
        be provided by a subclass or the constructor machinery. Confirm.
        """
        return self.gamma_y

    def _init_components(self):
        """
        Build the parent's components, then register the ``net_classif_y``
        classifier head.

        The head has input size ``self.zy_dim`` and output size
        ``self.dim_y``. NOTE(review): ``dim_y`` is assumed to be set by a
        parent class. Confirm.
        """
        super()._init_components()
        self.add_module(
            "net_classif_y",
            self.chain_node_builder.construct_classifier(self.zy_dim, self.dim_y),
        )
        # The property setter is only for other objects; internally, one
        # should use _net_classifier.