import abc

from torch import nn

# NOTE: g_list_model_penalized_reg_agg (used in list_inner_product below) is
# assumed to be a package-level helper in scope here; per the docstring below,
# it aggregates the list of penalized regularization tensors along the list
# while keeping the mini-batch dimension (e.g., the built-in sum is one
# stand-in consistent with that contract).

class AModel(nn.Module, metaclass=abc.ABCMeta):
"""
operations that all models (classification, segmentation, seq2seq)
"""
def __init__(self):
super().__init__()
self._decoratee = None
self.list_d_tr = None
self.visitor = None
self._net_invar_feat = None

    def extend(self, model):
        """
        extend this model with a decoratee, whose regularization loss will be
        combined with the loss of this model
        """
        self._decoratee = model
        self.reset_feature_extractor(model.net_invar_feat)
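
    # A hedged usage sketch of the decorator pattern above (model_a / model_b
    # are hypothetical instances of concrete AModel subclasses):
    #     model_a.extend(model_b)
    #     # model_a.cal_reg_loss(...) now returns model_a's own regularizers
    #     # followed by model_b's, and cal_loss penalizes both.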

    @property
    def metric4msel(self):
        """
        metric for model selection
        """
        raise NotImplementedError

    @property
    def multiplier4task_loss(self):
        """
        the multiplier for the task loss defaults to 1.0, except for VAE-family
        models
        """
        return 1.0

    def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
        """
        calculate the total loss: the weighted task loss plus the penalized
        regularization losses

        :return: total loss, list of regularization losses, task loss alone
        """
        list_loss, list_multiplier = self.cal_reg_loss(
            tensor_x, tensor_y, tensor_d, others
        )
        loss_reg = self.list_inner_product(list_loss, list_multiplier)
        loss_task_alone = self.cal_task_loss(tensor_x, tensor_y)
        loss_task = self.multiplier4task_loss * loss_task_alone
        return loss_task + loss_reg, list_loss, loss_task_alone
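
    # In symbols, cal_loss above computes (a restatement of the code, not a
    # new formula):
    #     loss = multiplier4task_loss * task_loss + sum_i multiplier_i * reg_i
    # where every term keeps the mini-batch dimension.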

    def list_inner_product(self, list_loss, list_multiplier):
        """
        compute the inner product between the list of regularization losses and
        the list of multipliers

        - the length of each list is the number of regularizers
        - for each element of the list, the first dimension of the tensor is
          the mini-batch

        the return value keeps the mini-batch structure, so aggregation here is
        only along the list
        """
        list_tuple = zip(list_loss, list_multiplier)
        list_penalized_reg = [mtuple[0] * mtuple[1] for mtuple in list_tuple]
        # aggregate along the list only, keeping the mini-batch dimension
        tensor_batch_penalized_loss = g_list_model_penalized_reg_agg(
            list_penalized_reg
        )
        return tensor_batch_penalized_loss
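
    # A hedged worked example (shapes are illustrative assumptions):
    #     list_loss       = [reg_a, reg_b]    # each of shape (batch_size,)
    #     list_multiplier = [1.0, 0.1]
    #     result          = 1.0 * reg_a + 0.1 * reg_b  # shape (batch_size,)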

    @abc.abstractmethod
    def cal_task_loss(self, tensor_x, tensor_y):
        """
        calculate the task loss

        :param tensor_x: input
        :param tensor_y: label
        :return: task loss
        """

    @abc.abstractmethod
    def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
        """
        task-independent regularization loss for domain generalization

        :return: list of regularization losses, list of multipliers
        """

    def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
        """
        task-independent regularization loss for domain generalization,
        combined with the regularization loss of the decoratee (if any)
        """
        loss_reg, mu = self._extend_loss(tensor_x, tensor_y, tensor_d, others)
        loss_reg_, mu_ = self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
        if loss_reg is not None:
            # list concatenation: own regularizers first, then the decoratee's
            return loss_reg_ + loss_reg, mu_ + mu
        return loss_reg_, mu_

    def _extend_loss(self, tensor_x, tensor_y, tensor_d, others=None):
        """
        fetch the regularization loss of the decoratee set via extend();
        return (None, None) if no decoratee exists
        """
        if self._decoratee is not None:
            return self._decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
        return None, None

    def forward(self, tensor_x):
        """
        forward pass: extract the semantic feature

        :param tensor_x: input tensor
        """
        out = self.extract_semantic_feat(tensor_x)
        return out

    @property
    def net_invar_feat(self):
        """
        if it exists, return the neural network for extracting invariant
        features
        """
        return self._net_invar_feat

    def reset_aux_net(self):
        """
        after the feature extractor is reset, the input dimension of auxiliary
        networks like the domain classifier will change accordingly (for
        command-line usage only)
        """
        # by default, do nothing

    def save(self, suffix=None):
        """
        persist model to disk
        """
        if self.visitor is None:
            return
        self.visitor.save(self, suffix)

    def load(self, suffix=None):
        """
        load model from disk
        """
        if self.visitor is None:
            return None
        return self.visitor.load(suffix)

    def set_saver(self, visitor):
        """
        set the visitor responsible for persisting and loading the model
        """
        self.visitor = visitor
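
    # A hedged usage sketch of the visitor pattern above ("my_visitor" is a
    # hypothetical object implementing save(model, suffix) and load(suffix)):
    #     model.set_saver(my_visitor)
    #     model.save(suffix="best")
    #     model_state = model.load(suffix="best")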

    def dset_decoration_args_algo(self, args, ddset):
        """
        decorate the dataset so that extra entries appear in each loaded item;
        for instance, JiGen needs a permutation index. this parent-class method
        delegates the decoration to its decoratee
        """
        if self._decoratee is not None:
            return self._decoratee.dset_decoration_args_algo(args, ddset)
        return ddset

    @property
    def p_na_prefix(self):
        """
        common prefix for Models
        """
        return "Model"

    @property
    def name(self):
        """
        get the name of the algorithm
        """
        na_prefix = self.p_na_prefix
        len_prefix = len(na_prefix)
        na_class = type(self).__name__
        if na_class[:len_prefix] != na_prefix:
            raise RuntimeError(
                f"Model builder node class must start with {na_prefix}; "
                f"the current class is named: {na_class}"
            )
        return na_class[len_prefix:].lower()
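
    # E.g., a subclass named "ModelDann" (an illustrative name, not
    # necessarily a class in the library) would report its name as "dann".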

    def print_parameters(self):
        """
        print all instance attributes of the object; can be used to inspect
        the configuration of every child class
        """
        params = vars(self)
        print(f"Parameters of {type(self).__name__}: {params}")