Source code for domainlab.compos.vae.zoo_vae_builders_classif

"""
Chain node VAE builders
"""
from domainlab.compos.vae.c_vae_builder_classif import (
    ChainNodeVAEBuilderClassifCondPrior,
)
from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import (
    DecoderConcatLatentFCReshapeConvGatedConv,
)
from domainlab.compos.vae.compos.encoder_xyd_parallel import (
    XYDEncoderParallelAlex,
    XYDEncoderParallelConvBnReluPool,
    XYDEncoderParallelExtern,
    XYDEncoderParallelUser,
)


class ChainNodeVAEBuilderClassifCondPriorBase(ChainNodeVAEBuilderClassifCondPrior):
    """
    Base class of AE builder.

    Concrete builders override ``is_myjob`` and ``build_encoder``;
    ``build_decoder`` is implemented here.
    """

    def config_img(self, flag, request):
        """Copy the image size attributes of ``request`` onto the builder.

        :param flag: when falsy, nothing is copied and the builder is
            left untouched
        :param request: object exposing ``i_c``, ``i_h`` and ``i_w``
        """
        if not flag:
            return
        self.i_c = request.i_c
        self.i_h = request.i_h
        self.i_w = request.i_w

    def is_myjob(self, request):
        """Decide whether this builder handles ``request``.

        :param request: the incoming request
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError

    def build_encoder(self):
        """Build the encoder.

        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError

    def build_decoder(self):
        """Build the decoder.

        The decoder's ``z_dim`` is the sum of ``zd_dim``, ``zx_dim`` and
        ``zy_dim``; the image attributes are the ones stored by
        ``config_img``.

        :return: a ``DecoderConcatLatentFCReshapeConvGatedConv`` instance
        """
        latent_size = self.zd_dim + self.zx_dim + self.zy_dim
        return DecoderConcatLatentFCReshapeConvGatedConv(
            z_dim=latent_size,
            i_c=self.i_c,
            i_w=self.i_w,
            i_h=self.i_h,
        )
class NodeVAEBuilderArg(ChainNodeVAEBuilderClassifCondPriorBase):
    """Build encoder decoder according to commandline arguments"""

    def is_myjob(self, request):
        """Record the request and report whether this builder applies.

        The request, its ``args`` and its image attributes are stored on
        the builder unconditionally, before the decision is made.

        :param request: object exposing ``args`` (with ``npath`` and
            ``npath_dom``) and the attributes read by ``config_img``
        :return: ``True`` when ``args.npath`` or ``args.npath_dom`` is
            not ``None``, otherwise ``False``
        """
        self.request = request
        self.args = request.args
        self.config_img(True, request)
        return self.args.npath is not None or self.args.npath_dom is not None

    def build_encoder(self):
        """Build the encoder from the stored commandline arguments.

        :return: an ``XYDEncoderParallelExtern`` instance
        """
        return XYDEncoderParallelExtern(
            self.zd_dim,
            self.zx_dim,
            self.zy_dim,
            args=self.args,
            i_c=self.i_c,
            i_h=self.i_h,
            i_w=self.i_w,
        )
class NodeVAEBuilderUser(ChainNodeVAEBuilderClassifCondPriorBase):
    """Build encoders according to test_mk_exp file"""

    def is_myjob(self, request):
        """Accept the request only when it carries no ``args`` attribute.

        The request is always stored on the builder; its image
        attributes are copied only when the request is accepted.

        :param request: the incoming request
        :return: ``True`` when ``request`` has no ``args`` attribute
        """
        self.request = request
        lacks_args = not hasattr(request, "args")
        self.config_img(lacks_args, request)
        return lacks_args

    def build_encoder(self):
        """Build the encoder from the networks carried by the request.

        :return: an ``XYDEncoderParallelUser`` built from the request's
            ``net_class_d``, ``net_x`` and ``net_class_y``
        """
        req = self.request
        return XYDEncoderParallelUser(req.net_class_d, req.net_x, req.net_class_y)
class NodeVAEBuilderImgConvBnPool(ChainNodeVAEBuilderClassifCondPriorBase):
    """Builder selected when the network name is ``conv_bn_pool_2``."""

    def is_myjob(self, request):
        """Report whether the request names the ``conv_bn_pool_2`` network.

        Image attributes are copied from the request only on a match.

        :param request: object exposing ``args`` with ``nname`` and
            ``nname_dom``
        :return: whether ``args.nname`` or ``args.nname_dom`` equals
            ``"conv_bn_pool_2"``
        """
        args = request.args
        # @FIXME
        matched = args.nname == "conv_bn_pool_2" or args.nname_dom == "conv_bn_pool_2"
        self.config_img(matched, request)
        return matched

    def build_encoder(self):
        """Build the encoder.

        :return: an ``XYDEncoderParallelConvBnReluPool`` instance
        """
        return XYDEncoderParallelConvBnReluPool(
            self.zd_dim,
            self.zx_dim,
            self.zy_dim,
            self.i_c,
            self.i_h,
            self.i_w,
        )
class NodeVAEBuilderImgAlex(NodeVAEBuilderImgConvBnPool):
    """NodeVAEBuilderImgAlex"""

    def is_myjob(self, request):
        """Report whether the request names the ``alexnet`` network.

        ``request.args`` is always stored on the builder; image
        attributes are copied only on a match. Only ``args.nname`` is
        consulted here (``nname_dom`` is not).

        :param request: object exposing ``args`` with ``nname``
        :return: whether ``args.nname`` equals ``"alexnet"``
        """
        self.args = request.args
        # @FIXME
        is_alexnet = self.args.nname == "alexnet"
        self.config_img(is_alexnet, request)
        return is_alexnet

    def build_encoder(self):
        """Build the encoder.

        :return: an ``XYDEncoderParallelAlex`` instance
        """
        return XYDEncoderParallelAlex(
            self.zd_dim,
            self.zx_dim,
            self.zy_dim,
            self.i_c,
            self.i_h,
            self.i_w,
            args=self.args,
        )