Source code for domainlab.compos.builder_nn_alex
from domainlab.compos.a_nn_builder import AbstractFeatExtractNNBuilderChainNode
from domainlab.compos.nn_zoo.nn_alex import Alex4DeepAll, AlexNetNoLastLayer
[docs]
def mkNodeFeatExtractNNBuilderNameAlex(arg_name4net, arg_val):
    """Create a builder chain-node class bound to one command line option.

    The returned class captures ``arg_name4net`` and ``arg_val`` through the
    closure: its ``is_myjob`` accepts a request only when the attribute named
    ``arg_name4net`` on the parsed arguments equals ``arg_val``.

    :param arg_name4net: name of the attribute to look up on the command
        line arguments object
    :param arg_val: value that attribute must equal for the node to take
        the job
    :return: the class ``NodeFeatExtractNNBuilderAlex`` (the class itself,
        not an instance)
    """

    class NodeFeatExtractNNBuilderAlex(AbstractFeatExtractNNBuilderChainNode):
        """Chain node which builds an AlexNet based feature extractor.

        Offers the uniform builder interface of the chain of responsibility
        so that callers do not construct the network classes directly.
        """

        def init_business(
            self, dim_out, args, isize=None, remove_last_layer=False, flag_pretrain=True
        ):
            """Construct the feature extractor, store it and return it.

            :param dim_out: output dimension, forwarded to ``Alex4DeepAll``;
                ignored when ``remove_last_layer`` is true
            :param args: command line arguments, not used by this class
            :param isize: input size, not used by this class
            :param remove_last_layer: if true build ``AlexNetNoLastLayer``,
                otherwise build ``Alex4DeepAll``
            :param flag_pretrain: forwarded as first argument to the network
                constructor (presumably selects pretrained weights --
                confirm in nn_alex)
            :return: the network kept in ``self.net_feat_extract``
            """
            if remove_last_layer:
                # this variant takes no output dimension
                self.net_feat_extract = AlexNetNoLastLayer(flag_pretrain)
            else:
                self.net_feat_extract = Alex4DeepAll(flag_pretrain, dim_out)
            return self.net_feat_extract

        def is_myjob(self, args):
            """Tell whether this node is responsible for the request.

            :param args: command line arguments; must carry the attribute
                whose name was given as ``arg_name4net`` to the factory
            :return: True if that attribute equals ``arg_val``
            """
            return getattr(args, arg_name4net) == arg_val

    return NodeFeatExtractNNBuilderAlex