"""
trainer matchdg
"""
import torch
from domainlab import g_inst_component_loss_agg, g_list_loss_agg
from domainlab.algos.trainers.a_trainer import AbstractTrainer
from domainlab.algos.trainers.compos.matchdg_match import MatchPair
from domainlab.algos.trainers.compos.matchdg_utils import (
dist_cosine_agg,
dist_pairwise_cosine,
get_base_domain_size4match_dg,
)
from domainlab.tasks.utils_task_dset import DsetIndDecorator4XYD
from domainlab.utils.logger import Logger
from domainlab.utils.hyperparameter_retrieval import get_gamma_reg
class TrainerMatchDG(AbstractTrainer):
"""
Contrastive Learning
"""
    def dset_decoration_args_algo(self, args, ddset):
        """
        wrap the dataset so that __getitem__ also returns the sample index,
        which is needed to build the match tensor
        """
        ddset = DsetIndDecorator4XYD(ddset)
        return ddset
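    # A minimal sketch of what the decoration above provides (illustrative
    # only; see DsetIndDecorator4XYD for the actual implementation): it wraps
    # a dataset yielding (x, y, d) so that __getitem__ also returns the
    # sample index, which the match tensor construction needs to align
    # samples across domains:
    #
    #     class DsetIndSketch(torch.utils.data.Dataset):
    #         def __init__(self, dset):
    #             self.dset = dset
    #         def __len__(self):
    #             return len(self.dset)
    #         def __getitem__(self, idx):
    #             tuple_xyd = self.dset[idx]
    #             return (*tuple_xyd, idx)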
def init_business(
self, model, task, observer, device, aconf, flag_accept=True, flag_erm=False
):
"""
initialize member objects
"""
super().init_business(model, task, observer, device, aconf, flag_accept)
# use the same batch size for match tensor
# so that order is kept!
self.base_domain_size = get_base_domain_size4match_dg(self.task)
self.epo_loss_tr = 0
self.flag_erm = flag_erm
self.lambda_ctr = get_gamma_reg(aconf, self.name)
self.mk_match_tensor(epoch=0)
self.flag_match_tensor_sweep_over = False
self.tuple_tensor_ref_domain2each_y = None
self.tuple_tensor_refdomain2each = None
    def tr_epoch(self, epoch):
        """
        data in one batch comes from two sources: one part from the random
        loader, the other part from the match tensor
        """
self.model.train()
self.epo_loss_tr = 0
logger = Logger.get_logger()
# update match tensor
if (epoch + 1) % self.aconf.epos_per_match_update == 0:
self.mk_match_tensor(epoch)
inds_shuffle = torch.randperm(self.tensor_ref_domain2each_domain_x.size(0))
# NOTE: match tensor size: N(ref domain size) * #(train domains) * (image size: c*h*w)
# self.tensor_ref_domain2each_domain_x[inds_shuffle]
# shuffles the match tensor at the first dimension
self.tuple_tensor_refdomain2each = torch.split(
self.tensor_ref_domain2each_domain_x[inds_shuffle], self.aconf.bs, dim=0
)
        # torch.split splits the tensor into chunks along dim 0;
        # each chunk is a view of the original tensor with first dimension
        # self.aconf.bs (the last chunk may be smaller);
        # the return value is a tuple of the split chunks
self.tuple_tensor_ref_domain2each_y = torch.split(
self.tensor_ref_domain2each_domain_y[inds_shuffle], self.aconf.bs, dim=0
)
logger.info(
f"number of batches in match tensor: {len(self.tuple_tensor_refdomain2each)}"
)
logger.info(
f"single batch match tensor size: {self.tuple_tensor_refdomain2each[0].shape}"
)
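        # an illustrative toy example of the chunking above (assumed shapes):
        #
        #     >>> t = torch.randn(10, 3, 1, 8, 8)  # 10 anchors, 3 domains
        #     >>> [c.shape[0] for c in torch.split(t, 4, dim=0)]
        #     [4, 4, 2]  # the last chunk may be smaller than the batch size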
        for batch_idx, (x_e, y_e, d_e, *others) in enumerate(self.loader_tr):
            # random loader with the same batch size as the match tensor
            # loader; the 4th output of self.loader_tr is not used here,
            # it is only used when creating the match tensor
self.tr_batch(epoch, batch_idx, x_e, y_e, d_e, others)
            if self.flag_match_tensor_sweep_over:
                logger.info(
                    "ref/base domain vs each domain match "
                    "traversed one sweep, starting new epoch"
                )
                self.flag_match_tensor_sweep_over = False
                break
if epoch < self.aconf.epochs_ctr:
logger.info("\n\nPhase ctr-only continue\n\n")
self.observer.reset()
return False
logger.info("\n\nPhase erm+ctr \n\n")
self.flag_erm = True
flag_stop = self.observer.update(epoch) # notify observer
return flag_stop
def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None):
"""
update network for each batch
"""
self.optimizer.zero_grad()
        x_e = x_e.to(self.device)  # e.g. shape (64, 1, 224, 224)
# y_e_scalar = torch.argmax(y_e, dim=1).to(self.device)
y_e = y_e.to(self.device)
# d_e = torch.argmax(d_e, dim=1).numpy()
d_e = d_e.to(self.device)
        # for each batch, the list of losses is re-initialized
        # the CTR (contrastive) loss differs between the CTR phase and the
        # ERM phase
        list_batch_loss_ctr = []
        # within a single batch, the loss needs to be aggregated across
        # different combinations of domains; accumulating in place via
        # loss_ctr += xxx on a leaf tensor can cause autograd problems,
        # so the losses are collected in a list and aggregated with a sum
if self.flag_erm:
# decoratee can be both trainer or model
list_loss_reg_rand, list_mu_reg = self.decoratee.cal_reg_loss(
x_e, y_e, d_e, others
)
loss_reg = self.model.list_inner_product(list_loss_reg_rand, list_mu_reg)
loss_task_rand = self.model.cal_task_loss(x_e, y_e)
# loss_erm_rnd_loader, *_ = self.model.cal_loss(x_e, y_e, d_e, others)
loss_erm_rnd_loader = (
loss_reg + loss_task_rand * self.model.multiplier4task_loss
)
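            # i.e., loosely, the composition above in the code's notation:
            #     loss = sum_k mu_k * reg_loss_k
            #            + multiplier4task_loss * task_loss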
num_batches_match_tensor = len(self.tuple_tensor_refdomain2each)
if batch_idx >= num_batches_match_tensor:
self.flag_match_tensor_sweep_over = True
return
curr_batch_size = self.tuple_tensor_refdomain2each[batch_idx].shape[0]
batch_tensor_ref_domain2each = self.tuple_tensor_refdomain2each[batch_idx].to(
self.device
)
        # make an order-5 tensor: (ref_domain, domain, channel, img_h, img_w)
        # with the first dimension as batch size;
        # collapse the first two dimensions so the network can map images
        # to features
batch_tensor_ref_domain2each = match_tensor_reshape(
batch_tensor_ref_domain2each
)
        # the first dim of batch_tensor_ref_domain2each is now no longer
        # the batch size!
        # batch_tensor_ref_domain2each.shape: torch.Size([40, channel, 224, 224])
batch_feat_ref_domain2each = self.model.extract_semantic_feat(
batch_tensor_ref_domain2each
)
        # batch_feat_ref_domain2each.shape: torch.Size([40, 512])
# torch.sum(torch.isnan(batch_tensor_ref_domain2each))
# assert not torch.sum(torch.isnan(batch_feat_ref_domain2each))
flag_isnan = torch.any(torch.isnan(batch_feat_ref_domain2each))
if flag_isnan:
logger = Logger.get_logger()
logger.info(batch_tensor_ref_domain2each)
            raise RuntimeError(
                "batch_feat_ref_domain2each NAN! is learning rate too big or "
                "hyper-parameter tau not set appropriately?"
            )
# for contrastive training phase,
# the last layer of the model is replaced with identity
batch_ref_domain2each_y = self.tuple_tensor_ref_domain2each_y[batch_idx].to(
self.device
)
batch_ref_domain2each_y = batch_ref_domain2each_y.view(
batch_ref_domain2each_y.shape[0] * batch_ref_domain2each_y.shape[1]
)
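        # e.g. an assumed (curr_batch_size=4, num_domain_tr=3) label block
        # becomes a flat vector of 12 labels, anchor-major, matching the
        # flattening of the images in match_tensor_reshape above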
if self.flag_erm:
# @FIXME: check if batch_ref_domain2each_y is
# continuous number which means it is at its initial value,
# not yet filled
loss_erm_match_tensor, *_ = self.model.cal_task_loss(
batch_tensor_ref_domain2each, batch_ref_domain2each_y.long()
)
        # create a tensor of shape (ref domain size, # train domains, feat size):
        # the match tensor's first two dimensions
        # [(ref domain size) * (# train domains)]
        # were collapsed together to get features extracted through
        # self.model, so the extracted features have to be reshaped back
        # into the match tensor layout; to make sure the reshape only
        # touches the first two dimensions, the feature dimension is
        # kept intact
dim_feat = batch_feat_ref_domain2each.shape[1]
num_domain_tr = len(self.task.list_domain_tr)
batch_feat_ref_domain2each = batch_feat_ref_domain2each.view(
curr_batch_size, num_domain_tr, dim_feat
)
batch_ref_domain2each_y = batch_ref_domain2each_y.view(
curr_batch_size, num_domain_tr
)
        # the match tensor's first two dimensions
        # [(ref domain size) * (# train domains)] were collapsed together
        # for feature extraction, so reshape the images back as well
        batch_tensor_ref_domain2each = batch_tensor_ref_domain2each.view(
            curr_batch_size,
            num_domain_tr,
            batch_tensor_ref_domain2each.shape[1],  # channel
            batch_tensor_ref_domain2each.shape[2],  # img_h
            batch_tensor_ref_domain2each.shape[3],  # img_w
        )
# Contrastive Loss: class \times domain \times domain
counter_same_cls_diff_domain = 1
logger = Logger.get_logger()
for y_c in range(self.task.dim_y):
subset_same_cls = batch_ref_domain2each_y[:, 0] == y_c
subset_diff_cls = batch_ref_domain2each_y[:, 0] != y_c
feat_same_cls = batch_feat_ref_domain2each[subset_same_cls]
feat_diff_cls = batch_feat_ref_domain2each[subset_diff_cls]
logger.debug(
f"class {y_c} with same class and different class: "
+ f"{feat_same_cls.shape[0]} {feat_diff_cls.shape[0]}"
)
            if feat_same_cls.shape[0] == 0 or feat_diff_cls.shape[0] == 0:
                logger.debug(
                    f"no instances of label {y_c} in the current batch, continue"
                )
                continue
if torch.sum(torch.isnan(feat_diff_cls)):
raise RuntimeError("feat_diff_cls has nan entrie(s)")
feat_diff_cls = feat_diff_cls.view(
feat_diff_cls.shape[0] * feat_diff_cls.shape[1], feat_diff_cls.shape[2]
)
for d_i in range(feat_same_cls.shape[1]):
dist_diff_cls_same_domain = dist_pairwise_cosine(
feat_same_cls[:, d_i, :], feat_diff_cls[:, :]
)
if torch.sum(torch.isnan(dist_diff_cls_same_domain)):
raise RuntimeError("dist_diff_cls_same_domain NAN")
# iterate other domains
for d_j in range(feat_same_cls.shape[1]):
if d_i >= d_j:
continue
dist_same_cls_diff_domain = dist_cosine_agg(
feat_same_cls[:, d_i, :], feat_same_cls[:, d_j, :]
)
if torch.sum(torch.isnan(dist_same_cls_diff_domain)):
raise RuntimeError("dist_same_cls_diff_domain NAN")
                    # the CTR (contrastive) loss differs between the
                    # CTR phase and the ERM phase
if self.flag_erm:
list_batch_loss_ctr.append(torch.sum(dist_same_cls_diff_domain))
else:
i_dist_same_cls_diff_domain = 1.0 - dist_same_cls_diff_domain
i_dist_same_cls_diff_domain = (
i_dist_same_cls_diff_domain / self.aconf.tau
)
partition = torch.log(
torch.exp(i_dist_same_cls_diff_domain)
+ dist_diff_cls_same_domain
)
list_batch_loss_ctr.append(
-1 * torch.sum(i_dist_same_cls_diff_domain - partition)
)
counter_same_cls_diff_domain += dist_same_cls_diff_domain.shape[0]
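        # in math, each ctr-phase term appended above for a positive pair
        # (features z_i, z_j of the same class from domains d_i != d_j) is an
        # InfoNCE-style loss (a sketch, using the code's notation):
        #     s_ij = (1 - d_cos(z_i, z_j)) / tau
        #     loss_ij = -(s_ij - log(exp(s_ij) + D_neg))
        # where D_neg aggregates pairwise cosine distances to
        # different-class features from the same domain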
loss_ctr = g_list_loss_agg(list_batch_loss_ctr) / counter_same_cls_diff_domain
if self.flag_erm:
epos = self.aconf.epos
else:
epos = self.aconf.epochs_ctr
percentage_finished_epochs = (epoch + 1) / (epos + 1)
# loss aggregation is over different domain
# combinations of the same batch
# https://discuss.pytorch.org/t/leaf-variable-was-used-in-an-inplace-operation/308
# Loosely, tensors you create directly are leaf variables.
# Tensors that are the result of a differentiable operation are
# not leaf variables
if self.flag_erm:
# extra loss of ERM phase: the ERM loss
# (the CTR loss for the ctr phase and erm phase are different)
# erm loss comes from two different data loaders,
# one is rnd (random) data loader
# the other one is the data loader from the match tensor
loss_e = (
torch.tensor(0.0, requires_grad=True)
+ g_inst_component_loss_agg(loss_erm_rnd_loader)
+ g_inst_component_loss_agg(loss_erm_match_tensor)
* self.model.multiplier4task_loss
+ self.lambda_ctr * percentage_finished_epochs * loss_ctr
)
else:
loss_e = (
torch.tensor(0.0, requires_grad=True)
+ self.lambda_ctr * percentage_finished_epochs * loss_ctr
)
# @FIXME: without torch.tensor(0.0), after a few epochs,
# error "'float' object has no attribute 'backward'"
loss_e.backward(retain_graph=False)
self.optimizer.step()
self.epo_loss_tr += loss_e.detach().item()
torch.cuda.empty_cache()
def mk_match_tensor(self, epoch):
"""
initialize or update match tensor
"""
obj_match = MatchPair(
self.task.dim_y,
self.task.isize.i_c,
self.task.isize.i_h,
self.task.isize.i_w,
self.aconf.bs,
virtual_ref_dset_size=self.base_domain_size,
num_domains_tr=len(self.task.list_domain_tr),
list_tr_domain_size=self.list_tr_domain_size,
)
# @FIXME: what is the usefulness of (epoch > 0) as argument
(
self.tensor_ref_domain2each_domain_x,
self.tensor_ref_domain2each_domain_y,
) = obj_match(
self.device,
self.task.loader_tr,
self.model.extract_semantic_feat,
(epoch > 0),
)
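        # shapes (illustrative, assuming N anchors in the reference domain):
        #     tensor_ref_domain2each_domain_x: (N, # train domains, c, h, w)
        #     tensor_ref_domain2each_domain_y: (N, # train domains)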
def before_tr(self):
"""
override abstract method
"""
logger = Logger.get_logger()
        logger.info("\n\nPhase 1 start: contrastive alignment without task loss\n\n")
        # phase 1: contrastive learning
        # different from phase 2, the ctr model has no classification loss
def match_tensor_reshape(batch_tensor_ref_domain2each):
    """
    the original dimension is (ref_domain, domain, channel, img_h, img_w);
    use a function so it is easier to accommodate other data modes (not image)
    """
    batch_tensor_refdomain_other_domain_chw = batch_tensor_ref_domain2each.view(
        batch_tensor_ref_domain2each.shape[0] * batch_tensor_ref_domain2each.shape[1],
        batch_tensor_ref_domain2each.shape[2],  # channel
        batch_tensor_ref_domain2each.shape[3],  # img_h
        batch_tensor_ref_domain2each.shape[4],  # img_w
    )
return batch_tensor_refdomain_other_domain_chw
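

# a minimal self-check sketch for match_tensor_reshape (illustrative only,
# with assumed toy shapes):
if __name__ == "__main__":
    dummy = torch.randn(4, 3, 1, 8, 8)  # (ref_domain, domain, c, h, w)
    flat = match_tensor_reshape(dummy)
    # the first two dimensions are collapsed, image dims kept intact
    assert flat.shape == (12, 1, 8, 8)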