Source code for domainlab.compos.vae.zoo_vae_builders_classif
"""
Chain node VAE builders
"""
from domainlab.compos.vae.c_vae_builder_classif import (
ChainNodeVAEBuilderClassifCondPrior,
)
from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import (
DecoderConcatLatentFCReshapeConvGatedConv,
)
from domainlab.compos.vae.compos.encoder_xyd_parallel import (
XYDEncoderParallelAlex,
XYDEncoderParallelConvBnReluPool,
XYDEncoderParallelExtern,
XYDEncoderParallelUser,
)
[docs]
class ChainNodeVAEBuilderClassifCondPriorBase(ChainNodeVAEBuilderClassifCondPrior):
    """
    Base class of the classification VAE builders defined in this module.

    Subclasses must implement ``is_myjob`` (decide whether they handle a
    request) and ``build_encoder``; the decoder is shared and built here.
    """

    def config_img(self, flag, request):
        """Copy image geometry from ``request`` onto ``self`` when ``flag`` is truthy.

        :param flag: when falsy, nothing is stored and the call is a no-op.
        :param request: object exposing ``i_c``, ``i_h`` and ``i_w``.
        """
        if not flag:
            return
        # copied one attribute at a time, in the order i_c, i_h, i_w
        for attr_name in ("i_c", "i_h", "i_w"):
            setattr(self, attr_name, getattr(request, attr_name))

    def is_myjob(self, request):
        """Decide whether this builder handles ``request``; subclasses override.

        :param request: the builder request.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def build_encoder(self):
        """Build the encoder; subclasses override.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def build_decoder(self):
        """Build the gated-conv decoder over the concatenated latents.

        Reads ``zd_dim``/``zx_dim``/``zy_dim`` (presumably provided by the
        parent builder class — not visible here) and the image geometry set
        by ``config_img``.

        :return: a ``DecoderConcatLatentFCReshapeConvGatedConv`` whose
            ``z_dim`` is ``zd_dim + zx_dim + zy_dim``.
        """
        concat_dim = self.zd_dim + self.zx_dim + self.zy_dim
        return DecoderConcatLatentFCReshapeConvGatedConv(
            z_dim=concat_dim, i_c=self.i_c, i_w=self.i_w, i_h=self.i_h
        )
[docs]
class NodeVAEBuilderArg(ChainNodeVAEBuilderClassifCondPriorBase):
    """Build encoder and decoder according to command-line arguments.

    Handles requests whose ``args`` specify ``npath`` or ``npath_dom``.
    """

    def is_myjob(self, request):
        """Accept the request when a network path argument is given.

        Regardless of the outcome, stores ``request``, ``request.args`` and
        the image geometry on ``self``.

        :param request: object with ``args`` and ``i_c``/``i_h``/``i_w``.
        :return: True if ``args.npath`` or ``args.npath_dom`` is not None.
        """
        self.request = request
        self.args = request.args
        self.config_img(True, request)
        npath_given = self.args.npath is not None
        return npath_given or self.args.npath_dom is not None

    def build_encoder(self):
        """Build the external-network encoder described by ``self.args``.

        :return: an ``XYDEncoderParallelExtern`` instance.
        """
        latent_sizes = (self.zd_dim, self.zx_dim, self.zy_dim)
        return XYDEncoderParallelExtern(
            *latent_sizes,
            args=self.args,
            i_c=self.i_c,
            i_h=self.i_h,
            i_w=self.i_w,
        )
[docs]
class NodeVAEBuilderUser(ChainNodeVAEBuilderClassifCondPriorBase):
    """Build encoders from networks attached to the request (see test_mk_exp).

    Handles requests that carry no ``args`` attribute.
    """

    def is_myjob(self, request):
        """Accept the request when it has no ``args`` attribute.

        Always stores ``request`` on ``self``; the image geometry is copied
        only when the request is accepted.

        :param request: the builder request.
        :return: True if ``request`` lacks an ``args`` attribute.
        """
        self.request = request
        accepted = not hasattr(request, "args")
        self.config_img(accepted, request)
        return accepted

    def build_encoder(self):
        """Wrap the request's ``net_class_d``, ``net_x`` and ``net_class_y``.

        :return: an ``XYDEncoderParallelUser`` instance.
        """
        req = self.request
        return XYDEncoderParallelUser(req.net_class_d, req.net_x, req.net_class_y)
[docs]
class NodeVAEBuilderImgConvBnPool(ChainNodeVAEBuilderClassifCondPriorBase):
    """Build a conv-bn-relu-pool encoder when ``conv_bn_pool_2`` is requested."""

    def is_myjob(self, request):
        """Accept the request if either network name is ``conv_bn_pool_2``.

        :param request: object with ``args.nname``, ``args.nname_dom`` and
            ``i_c``/``i_h``/``i_w``.
        :return: True when ``args.nname`` or (checked only if needed)
            ``args.nname_dom`` equals ``"conv_bn_pool_2"``.
        """
        cli_args = request.args
        # @FIXME: a match on either name selects this encoder; presumably the
        # two names should be honored separately — confirm intended behavior.
        matched = (
            cli_args.nname == "conv_bn_pool_2"
            or cli_args.nname_dom == "conv_bn_pool_2"
        )
        self.config_img(matched, request)
        return matched

    def build_encoder(self):
        """Build the encoder from the latent sizes and image geometry.

        :return: an ``XYDEncoderParallelConvBnReluPool`` instance.
        """
        sizes = (self.zd_dim, self.zx_dim, self.zy_dim, self.i_c, self.i_h, self.i_w)
        return XYDEncoderParallelConvBnReluPool(*sizes)
[docs]
class NodeVAEBuilderImgAlex(NodeVAEBuilderImgConvBnPool):
    """Build an AlexNet-based encoder when ``args.nname`` is ``alexnet``."""

    def is_myjob(self, request):
        """Accept the request if ``args.nname`` equals ``"alexnet"``.

        Always stores ``request.args`` on ``self``; the image geometry is
        copied only when the request is accepted.

        :param request: object with ``args`` and ``i_c``/``i_h``/``i_w``.
        :return: True when ``args.nname == "alexnet"``.
        """
        self.args = request.args
        # @FIXME: unlike the parent class, ``nname_dom`` is not consulted here.
        wants_alex = self.args.nname == "alexnet"
        self.config_img(wants_alex, request)
        return wants_alex

    def build_encoder(self):
        """Build the AlexNet encoder from latent sizes, geometry and ``self.args``.

        :return: an ``XYDEncoderParallelAlex`` instance.
        """
        zd, zx, zy = self.zd_dim, self.zx_dim, self.zy_dim
        return XYDEncoderParallelAlex(
            zd, zx, zy, self.i_c, self.i_h, self.i_w, args=self.args
        )