Source code for domainlab.compos.builder_nn_conv_bn_relu_2
from domainlab.compos.a_nn_builder import AbstractFeatExtractNNBuilderChainNode
from domainlab.compos.nn_zoo.net_conv_conv_bn_pool_2 import NetConvBnReluPool2L
[docs]
def mkNodeFeatExtractNNBuilderNameConvBnRelu2(arg_name4net, arg_val, conv_stride):
"""mkNodeFeatExtractNNBuilderNameConvBnRelu2.
In chain of responsibility selection of neural network, reuse code to add
more possibilities of neural network of the same family.
:param arg_name4net: name of nn in args
:param arg_val: the registered name of the neural network to be added
:param conv_stride: should be 1 for 28*28 images
:param i_c:
:param i_h:
:param i_w:
"""
class _NodeFeatExtractNNBuilderConvBnRelu2L(AbstractFeatExtractNNBuilderChainNode):
"""NodeFeatExtractNNBuilderConvBnRelu2L."""
def init_business(
self, dim_out, args, isize, flag_pretrain=None, remove_last_layer=False
):
"""
:param flag_pretrain
"""
self.net_feat_extract = NetConvBnReluPool2L(
isize=isize, conv_stride=conv_stride, dim_out_h=dim_out
)
return self.net_feat_extract
def is_myjob(self, args):
"""is_myjob.
:param args: command line arguments:
"--nname": name of the torchvision model
"""
arg_name = getattr(args, arg_name4net)
return arg_name == arg_val
return _NodeFeatExtractNNBuilderConvBnRelu2L