Source code for domainlab.utils.u_import_net_module
"""
Import an externally defined neural network implementation.
"""
from domainlab.utils.logger import Logger
from domainlab.utils.u_import import import_path
[docs]
def build_external_obj_net_module_feat_extract(mpath, dim_y, remove_last_layer):
    """Build a network object from a function defined in an external python file.

    The external file must define a function
    ``build_feat_extract_net(dim_y, remove_last_layer)`` that returns a neural
    network object. Letting the user instantiate the network is fine for
    training, but problematic for persistence of the trained model, since the
    network is created outside of domainlab.

    :param mpath: path of the external python file where the neural network
        architecture is defined
    :param dim_y: dimension of features; forwarded unchanged to
        ``build_feat_extract_net``
    :param remove_last_layer: forwarded unchanged to
        ``build_feat_extract_net``
    :return: the object returned by ``build_feat_extract_net``
    :raises RuntimeError: if the external file does not define
        ``build_feat_extract_net``, or if that function returns ``None``
    :raises Exception: any exception raised while calling
        ``build_feat_extract_net`` is logged and re-raised unchanged
    """
    # @FIXME: hard coded, move to top level __init__ definition in domainlab
    name_signature = "build_feat_extract_net(dim_y, remove_last_layer)"
    # the function name is the signature text before the opening parenthesis
    name_fun = name_signature.partition("(")[0]

    net_module = import_path(mpath)
    builder = getattr(net_module, name_fun, None)
    if builder is None:
        raise RuntimeError(
            f"Please implement a function {name_signature} "
            "in your external python file"
        )

    try:
        net = builder(dim_y, remove_last_layer)
    except Exception:
        # give the user a hint about the expected contract, then propagate
        # the original error with its traceback intact
        logger = Logger.get_logger()
        logger.error(
            f"function {name_signature} should return a neural network "
            "(pytorch module) that extracts features from an image"
        )
        raise

    if net is None:
        raise RuntimeError(
            f"the pytorch module returned by {name_signature} is None"
        )
    return net