Source code for domainlab
"""
globals for the whole package
"""
__docformat__ = "restructuredtext"
import torch
# aggregation of the component losses (e.g. per-pixel or per-dimension losses)
# inside a single instance
g_inst_component_loss_agg = torch.sum
# aggregation of per-instance losses across a batch tensor
g_tensor_batch_agg = torch.sum
# aggregation of a plain Python list of scalar losses
g_list_loss_agg = sum
# key naming the number of shared hyperparameter samples used in random search
g_name_num_shared_param_samples_rand_search = "num_shared_param_samples"
def g_list_model_penalized_reg_agg(list_penalized_reg):
    """
    aggregate the penalized regularization terms along the list dimension,
    without collapsing the batch structure of each tensor
    """
    return torch.stack(list_penalized_reg, dim=0).sum(dim=0)
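# A minimal usage sketch (illustrative only, not part of the original module); reg_a and
# reg_b are hypothetical per-batch regularization tensors for a batch of 3 instances:
#   >>> reg_a = torch.tensor([0.1, 0.2, 0.3])
#   >>> reg_b = torch.tensor([1.0, 1.0, 1.0])
#   >>> g_list_model_penalized_reg_agg([reg_a, reg_b])
#   tensor([1.1000, 1.2000, 1.3000])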
# aggregation (reduction) mode for the cross-entropy loss; "none" keeps per-instance losses
g_str_cross_entropy_agg = "none"
# "component loss" refers to the aggregation of per-component losses inside a single
# instance, e.g. per-pixel reconstruction losses or the individual components of the
# KL-divergence loss. The instance-level loss currently uses torch.sum, which has
# essentially the same effect as torch.mean (the two differ only by a constant factor);
# the important choice is the component aggregation method inside a single instance.
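# A minimal sketch (illustrative assumption: a hypothetical per-pixel loss tensor of
# shape (batch, height, width)) showing how the two aggregation levels combine; this is
# not necessarily how the library invokes these aggregators:
#   >>> pixel_loss = torch.rand(4, 8, 8)                                  # batch of 4
#   >>> per_instance = g_inst_component_loss_agg(pixel_loss, dim=(1, 2))  # component agg
#   >>> per_instance.shape
#   torch.Size([4])
#   >>> g_tensor_batch_agg(per_instance).shape                            # scalar batch loss
#   torch.Size([])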