domainlab.algos.trainers package



domainlab.algos.trainers.a_trainer module

Base Class for trainer

class domainlab.algos.trainers.a_trainer.AbstractTrainer(successor_node=None, extend=None)[source]

Bases: AbstractChainNodeHandler

Algorithm director that controls the data flow

after_batch(epoch, ind_batch)[source]
  • epoch

  • ind_batch

before_batch(epoch, ind_batch)[source]
  • epoch

  • ind_batch


before training, probe model performance

cal_reg_loss(tensor_x, tensor_y, tensor_d, others=None)[source]

decorate the trainer regularization loss: combine the losses of the current trainer with those returned by self._model.cal_reg_loss, where self._model can be either a trainer or a model
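Here is a minimal sketch of this decoration. It assumes each node returns a list of loss terms together with a list of multipliers, and appends its own terms to those of its decoratee. `SketchModel`, `SketchTrainer`, `_cal_own_reg_loss` and the list-based return convention are hypothetical stand-ins, not the DomainLab classes.

```python
from typing import List, Tuple

Losses = Tuple[List[float], List[float]]  # (loss terms, multipliers)


class SketchModel:
    """Hypothetical innermost model contributing one regularization term."""

    def cal_reg_loss(self, x: float) -> Losses:
        return [0.5 * x], [1.0]


class SketchTrainer:
    """Hypothetical trainer decorating another trainer or a model."""

    def __init__(self, decoratee, own_multiplier: float):
        self._model = decoratee
        self.own_multiplier = own_multiplier

    def _cal_own_reg_loss(self, x: float) -> Losses:
        return [x ** 2], [self.own_multiplier]

    def cal_reg_loss(self, x: float) -> Losses:
        # delegate first: the decoratee may itself be a trainer or a model
        inner_losses, inner_mults = self._model.cal_reg_loss(x)
        own_losses, own_mults = self._cal_own_reg_loss(x)
        # concatenate so that every term keeps its own multiplier
        return inner_losses + own_losses, inner_mults + own_mults


chain = SketchTrainer(SketchTrainer(SketchModel(), own_multiplier=0.1), own_multiplier=2.0)
losses, mults = chain.cal_reg_loss(3.0)
total_reg = sum(m * loss for m, loss in zip(mults, losses))
print(losses, mults)  # [1.5, 9.0, 9.0] [1.0, 0.1, 2.0]
```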


estimate the scale of each loss term and match each loss term to the major loss via a ratio; this ratio is then multiplied with the multiplier
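The ratio can be illustrated as follows. The sketch assumes the ratio for each term is |task loss| / |regularization loss|. The function `scale_multipliers`, the `eps` guard and this exact formula are assumptions rather than the documented implementation.

```python
from typing import List


def scale_multipliers(loss_task: float, reg_losses: List[float],
                      multipliers: List[float], eps: float = 1e-8) -> List[float]:
    """Hypothetical: scale each multiplier by |task loss| / |reg loss| so that
    every regularization term starts on the scale of the major loss."""
    scaled = []
    for reg, mult in zip(reg_losses, multipliers):
        ratio = abs(loss_task) / (abs(reg) + eps)  # match the term to the major loss
        scaled.append(ratio * mult)                # ratio multiplied with the multiplier
    return scaled


new_mults = scale_multipliers(2.0, [0.02, 200.0], [1.0, 0.5])
```

With this choice, each scaled term has a magnitude of about |task loss| times the user-given multiplier. The multiplier therefore acts as a relative weight rather than an absolute one.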

property decoratee
dset_decoration_args_algo(args, ddset)[source]

decorate the dataset to get extra entries in the loaded item; for instance, JiGen needs the permutation index. This parent-class function delegates the decoration to its decoratee


extend current trainer with another trainer


recursively get the “real” model from trainer
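A trainer's model can itself be another trainer, so unwrapping is naturally recursive. The sketch below assumes hypothetical `Trainer` and `Model` classes and a helper `get_real_model` (not DomainLab names). Trainers are told apart from models with `isinstance`.

```python
class Model:
    """Hypothetical leaf: a plain model that wraps nothing."""


class Trainer:
    """Hypothetical trainer whose model may be a model or another trainer."""

    def __init__(self, model):
        self._model = model


def get_real_model(node):
    # keep unwrapping while the node is a trainer wrapping another node
    if isinstance(node, Trainer):
        return get_real_model(node._model)
    return node


leaf = Model()
outer = Trainer(Trainer(Trainer(leaf)))
assert get_real_model(outer) is leaf
```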

init_business(model, task, observer, device, aconf, flag_accept=True)[source]

  • model

  • task

  • observer

  • device

  • aconf


request – string

property list_tr_domain_size

get a list of training domain sizes

property model

property model, which can be another trainer or model

property name

get the name of the algorithm

property p_na_prefix

common prefix for Trainers


after training


Function to print all parameters of the object. Can be used to print the parameters of any child class


make a new optimizer to clear internal state

property str_metric4msel

metric for model selection

abstract tr_epoch(epoch)[source]


domainlab.algos.trainers.a_trainer.mk_opt(model, aconf)[source]

create optimizer

domainlab.algos.trainers.args_dial module

domain invariant adversarial trainer hyper-parameters


append hyper-parameters to the main argparser

domainlab.algos.trainers.args_miro module

miro trainer configurations


append hyper-parameters to the main argparser

domainlab.algos.trainers.hyper_scheduler module

update hyper-parameters during training

class domainlab.algos.trainers.hyper_scheduler.HyperSchedulerWarmupExponential(trainer, **kwargs)[source]

Bases: HyperSchedulerWarmupLinear

HyperScheduler Exponential

warmup(par_setpoint, epoch)[source]

start from a small value of par and ramp it up to the steady-state value par_setpoint over total_steps steps

  • epoch
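Here is a sketch of an exponential-style warmup. It assumes the sigmoid-shaped ramp 2 / (1 + exp(-rate * epoch / total_steps)) - 1 known from DANN-style schedules. The constant `rate` and this exact formula are assumptions, and DomainLab's scheduler may differ.

```python
import math


def warmup_exponential(par_setpoint: float, epoch: int, total_steps: int,
                       rate: float = 10.0) -> float:
    """Sketch: ramp from 0 at epoch 0 towards par_setpoint, fast at first and
    then saturating; rate is an assumed steepness constant."""
    progress = epoch / total_steps
    ratio = 2.0 / (1.0 + math.exp(-rate * progress)) - 1.0
    return ratio * par_setpoint


values = [warmup_exponential(1.0, epoch, total_steps=10) for epoch in range(11)]
```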

class domainlab.algos.trainers.hyper_scheduler.HyperSchedulerWarmupLinear(trainer, **kwargs)[source]

Bases: object


set number of total_steps to gradually change optimization parameter

warmup(par_setpoint, epoch)[source]

warmup: start from a small value of par and ramp it up to the steady-state value par_setpoint over total_steps steps

  • epoch
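A linear warmup can be sketched as below. The sketch assumes the value grows in proportion to epoch / total_steps and is held at par_setpoint afterwards. Whether DomainLab counts from `epoch` or `epoch + 1` is not stated here.

```python
def warmup_linear(par_setpoint: float, epoch: int, total_steps: int) -> float:
    """Sketch: ramp linearly from 0 to par_setpoint over total_steps, then hold."""
    ratio = min(epoch / total_steps, 1.0)
    return ratio * par_setpoint


print([warmup_linear(4.0, epoch, total_steps=4) for epoch in range(6)])
# [0.0, 1.0, 2.0, 3.0, 4.0, 4.0]
```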

domainlab.algos.trainers.mmd_base module

Alexej, Xudong

class domainlab.algos.trainers.mmd_base.TrainerMMDBase(successor_node=None, extend=None)[source]

Bases: TrainerBasic

causal matching

gaussian_kernel(x, y)[source]

kernel for MMD

mmd(x, y)[source]

maximum mean discrepancy

my_cdist(x1, x2)[source]

pairwise distance for the Gaussian kernel
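The three methods fit together as in the sketch below. It is a NumPy port of the multi-bandwidth Gaussian kernel used in DomainBed's MMD implementation. The bandwidth list `gammas` is taken from that assumption, and DomainLab itself operates on PyTorch tensors.

```python
import numpy as np


def my_cdist(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between rows of x1 and rows of x2."""
    x1_norm = (x1 ** 2).sum(axis=-1, keepdims=True)    # shape (n, 1)
    x2_norm = (x2 ** 2).sum(axis=-1, keepdims=True).T  # shape (1, m)
    dist = x1_norm + x2_norm - 2.0 * x1 @ x2.T
    return np.clip(dist, 1e-30, None)                  # guard against tiny negatives


def gaussian_kernel(x, y, gammas=(0.001, 0.01, 0.1, 1, 10, 100, 1000)):
    """Sum of Gaussian (RBF) kernels over several bandwidths."""
    dist = my_cdist(x, y)
    return sum(np.exp(-g * dist) for g in gammas)


def mmd(x, y):
    """Biased estimate of the squared maximum mean discrepancy."""
    return (gaussian_kernel(x, x).mean() + gaussian_kernel(y, y).mean()
            - 2.0 * gaussian_kernel(x, y).mean())


rng = np.random.default_rng(0)
feat_a = rng.normal(0.0, 1.0, size=(64, 8))
feat_b = rng.normal(0.0, 1.0, size=(64, 8))  # same distribution as feat_a
feat_c = rng.normal(2.0, 1.0, size=(64, 8))  # shifted distribution
```

The shifted batch `feat_c` should give the larger discrepancy against `feat_a`.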

domainlab.algos.trainers.train_basic module

basic trainer

class domainlab.algos.trainers.train_basic.TrainerBasic(successor_node=None, extend=None)[source]

Bases: AbstractTrainer

basic trainer


observer collects information


set the model to train mode and initialize some member variables


check the performance of randomly initialized weights

cal_loss(tensor_x, tensor_y, tensor_d, others)[source]

so that the user API can use trainer.cal_loss to train

log_loss(list_b_reg_loss, loss_task, loss)[source]

only for logging self.epo_reg_loss_tr

tr_batch(tensor_x, tensor_y, tensor_d, others, ind_batch, epoch)[source]

optimize the neural network by one step on a mini-batch of data



domainlab.algos.trainers.train_basic.list_divide(list_val, scalar)[source]

domainlab.algos.trainers.train_causIRL module

Alex, Xudong

class domainlab.algos.trainers.train_causIRL.TrainerCausalIRL(successor_node=None, extend=None)[source]

Bases: TrainerBasic

causal matching

gaussian_kernel(x, y)[source]

kernel for MMD

mmd(x, y)[source]

maximum mean discrepancy

my_cdist(x1, x2)[source]

pairwise distance for the Gaussian kernel

tr_batch(tensor_x, tensor_y, tensor_d, others, ind_batch, epoch)[source]

optimize the neural network by one step on a mini-batch of data

domainlab.algos.trainers.train_coral module

Deep CORAL: Correlation Alignment for Deep Domain Adaptation [au] Alexej, Xudong

class domainlab.algos.trainers.train_coral.TrainerCoral(successor_node=None, extend=None)[source]

Bases: TrainerMMDBase

cross domain MMD


domain-pairwise mmd
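Deep CORAL aligns second-order statistics. The NumPy sketch below assumes the DomainBed-style variant, which penalizes differences of both feature means and feature covariances, averaged over all pairs of domains. The function names are hypothetical, and the exact reduction in `TrainerCoral` may differ. `np.cov` with `rowvar=False` uses the unbiased n - 1 normalization.

```python
import numpy as np


def coral_discrepancy(x: np.ndarray, y: np.ndarray) -> float:
    """Squared difference of feature means plus squared difference of feature
    covariances between two batches (rows are samples)."""
    mean_diff = ((x.mean(axis=0) - y.mean(axis=0)) ** 2).mean()
    cov_diff = ((np.cov(x, rowvar=False) - np.cov(y, rowvar=False)) ** 2).mean()
    return float(mean_diff + cov_diff)


def domain_pairwise_penalty(feats_per_domain):
    """Average the discrepancy over all unordered pairs of domains (needs >= 2)."""
    n = len(feats_per_domain)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    total = sum(coral_discrepancy(feats_per_domain[i], feats_per_domain[j]) for i, j in pairs)
    return total / len(pairs)


rng = np.random.default_rng(0)
dom0 = rng.normal(size=(128, 4))
dom1 = rng.normal(size=(128, 4))
dom2 = 3.0 * rng.normal(size=(128, 4))  # same mean, inflated covariance
```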



domainlab.algos.trainers.train_dial module

use random start to generate adversarial images

class domainlab.algos.trainers.train_dial.TrainerDIAL(successor_node=None, extend=None)[source]

Bases: TrainerBasic

Trainer Domain Invariant Adversarial Learning

gen_adversarial(device, img_natural, vec_y)[source]

use naive trimming to find an optimized image in the direction of the adversarial gradient; the result is not necessarily optimal under the constraint because of nonlinearity, since the constraint epsilon is only enforced ad hoc
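The random start and the naive trimming can be sketched as projected sign-gradient ascent (PGD-style) in NumPy. The parameters `epsilon`, `step_size` and `steps`, and the toy gradient function, are assumptions. In DomainLab the gradient would come from the model's loss via autograd.

```python
import numpy as np


def gen_adversarial_sketch(img_natural, grad_fn, epsilon=8 / 255, step_size=2 / 255,
                           steps=5, rng=None):
    """Random start plus sign-gradient ascent, with naive trimming (clipping)
    to the L-infinity ball of radius epsilon and to the pixel range [0, 1]."""
    if rng is None:
        rng = np.random.default_rng()
    # random start: uniform noise inside the epsilon ball
    img_adv = img_natural + rng.uniform(-epsilon, epsilon, size=img_natural.shape)
    img_adv = np.clip(img_adv, 0.0, 1.0)
    for _ in range(steps):
        img_adv = img_adv + step_size * np.sign(grad_fn(img_adv))  # increase the loss
        # ad-hoc trimming: project into the ball, then into the valid pixel range
        img_adv = np.clip(img_adv, img_natural - epsilon, img_natural + epsilon)
        img_adv = np.clip(img_adv, 0.0, 1.0)
    return img_adv


# toy loss sum((x - 0.5) ** 2) with gradient 2 * (x - 0.5), a stand-in for autograd
rng = np.random.default_rng(0)
img = rng.uniform(0.0, 1.0, size=(3, 8, 8))
adv = gen_adversarial_sketch(img, lambda x: 2.0 * (x - 0.5), rng=rng)
```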

domainlab.algos.trainers.train_ema module

simple exponential moving average of the weights of each layer, applied after each epoch; when chaining trainers, e.g. trainer=ma_trainer2_trainer3, always set ma to be the outermost one

Paper: Ensemble of Averages: Improving Model Selection and Boosting Performance in Domain Generalization Devansh Arpit, Huan Wang, Yingbo Zhou, Caiming Xiong Salesforce Research, USA

class domainlab.algos.trainers.train_ema.TrainerMA(successor_node=None, extend=None)[source]

Bases: TrainerBasic

The initializer of this class is placed in a dedicated block/section of the abstract class initializer; otherwise it would break the class inheritance.


observer collects information

move_average(dict_data, epoch)[source]

for each epoch, take a convex combination of the weights of each layer (see the Ensemble of Averages paper cited above)
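Here is a sketch of the layer-wise convex combination on a dictionary of weight arrays. It assumes the weight on the running average is epoch / (epoch + 1), which reproduces the simple running mean over epochs used by Ensemble of Averages. A constant weight would instead give an exponential moving average. The function name is hypothetical.

```python
import numpy as np


def move_average_sketch(avg_params, new_params, epoch):
    """Layer-wise convex combination of the running average and the current
    weights; weight_old = epoch / (epoch + 1) yields the simple running mean."""
    if avg_params is None:  # first epoch: the average is just the current weights
        return {name: p.copy() for name, p in new_params.items()}
    weight_old = epoch / (epoch + 1.0)
    return {name: weight_old * avg_params[name] + (1.0 - weight_old) * new_params[name]
            for name in new_params}


history = [{"layer1": np.full(2, v)} for v in (1.0, 2.0, 6.0)]  # toy weights per epoch
avg = None
for epoch, params in enumerate(history):
    avg = move_average_sketch(avg, params, epoch)
```

After three epochs with weights 1, 2 and 6, the average equals their plain mean, 3.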

domainlab.algos.trainers.train_fishr module

trainer for Fishr: invariant gradient variances across domains

class domainlab.algos.trainers.train_fishr.TrainerFishr(successor_node=None, extend=None)[source]

Bases: TrainerBasic

The goal is to minimize the variance of the domain-level variance of the gradients. This aligns the domain-level loss landscapes locally around the final weights, reducing inconsistencies across domains.

For more details, see: Alexandre Ramé, Corentin Dancette, and Matthieu Cord. “Fishr: Invariant gradient variances for out-of-distribution generalization.” International Conference on Machine Learning. PMLR, 2022.

cal_dict_variance_grads(tensor_x, vec_y)[source]

Calculates the domain-level variances of the gradients w.r.t. each scalar component of the weight tensor for the layer in question, i.e. $$v_i = \operatorname{Var}\left(\nabla_{\theta}\ell(x^{(d_i)}, y^{(d_i)})\right)$$, where $$d_i$$ denotes data coming from domain $$i$$. The computation is done using the BackPACK package.

Input: tensor_x, a tensor, where the first dimension is the batch size and vec_y, which is a vector representing the output labels.

Return: dictionary, where the key is the name for the layer of a neural network and the value is the diagonal variance of each scalar component of the gradient of the loss w.r.t. the parameter.

Return Example: {“layer1”: Tensor[batchsize=32, 64, 3, 11, 11]} for a convolution kernel
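BackPACK extracts per-sample gradient statistics for arbitrary PyTorch layers. For a single linear layer with squared loss, the per-sample gradient has the closed form (x·w - y) x, which makes $$v_i$$ easy to see. The model, the loss and the dictionary key `"linear"` are assumptions for illustration.

```python
import numpy as np


def per_domain_grad_variance(w, x_dom, y_dom):
    """Component-wise variance, over one domain's minibatch, of the per-sample
    gradients of 0.5 * (x @ w - y) ** 2 with respect to w."""
    residual = x_dom @ w - y_dom                  # shape (batch,)
    per_sample_grads = residual[:, None] * x_dom  # shape (batch, dim), one gradient per row
    return per_sample_grads.var(axis=0)           # shape (dim,), population variance


rng = np.random.default_rng(0)
w = np.array([1.0, -2.0])
x_dom0 = rng.normal(size=(32, 2))
y_dom0 = rng.normal(size=32)
dict_var_dom0 = {"linear": per_domain_grad_variance(w, x_dom0, y_dom0)}
```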


Calculates the mean across several dictionaries.

Input: list of dictionaries, where the values of each dictionary are tensors.

Return: dictionary, where the values are tensors. The scalar values of the tensors contain the mean across the first dimension of the dictionaries from the list of inputs.


Calculates the element-wise power of the values in a dictionary, when the values are tensors.

Input: dictionary, where the values are tensors.

Return: dictionary, where the values are tensors. The scalar values of the tensors are the element-wise power of the scalars in the input dictionary.




Calculate the domain-level variance of the gradients and the layer-wise ERM loss.

Input: a tuple containing lists with the data per domain.

Return: two lists. The first one contains dictionaries with the gradient variances: the keys are the layers, and the values are tensors storing the gradient variances. The second list contains the losses; each list entry represents the summed-up ERM loss of a single layer.


Computes the variance of the domain-level gradient variances, layer-wise. Let $$v = \frac{1}{n}\sum_{i=1}^{n} v_i$$ represent the mean across $$n$$ domains, with $$v_i = \operatorname{Var}\left(\nabla_{\theta}\ell(x^{(d_i)}, y^{(d_i)})\right)$$, where $$d_i$$ denotes data coming from domain $$i$$. We are interested in $$\frac{1}{n}\sum_{i=1}^{n}(v_i - v)^2 = \frac{1}{n}\sum_{i=1}^{n} v_i^2 - v^2.$$

Input: a list of dictionaries, each with the structure {“layer1”: tensor[64, 3, 11, 11], “layer2”: tensor[8, 3, 5, 5], …}. The scalar values in each dictionary are the variances of the gradient of the loss w.r.t. the scalar components of the weight tensor for the layer in question, where the variance is computed over the minibatch of a particular domain.

Return: dictionary, containing the layers as keys and tensors as values. The variances are stored in the tensors as scalars.
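The helpers and the identity above fit together as in this sketch over dictionaries of NumPy arrays. `dict_mean`, `dict_pow` and `variance_between_domains` are hypothetical names, not necessarily DomainLab's.

```python
import numpy as np


def dict_mean(list_dicts):
    """Element-wise mean, over the list, of dictionaries holding arrays."""
    return {k: np.mean([d[k] for d in list_dicts], axis=0) for k in list_dicts[0]}


def dict_pow(dict_in, power=2):
    """Element-wise power of every array in a dictionary."""
    return {k: v ** power for k, v in dict_in.items()}


def variance_between_domains(list_dict_var):
    """Layer-wise (1/n) * sum_i (v_i - v)^2, computed as mean(v_i^2) - v^2."""
    mean_v = dict_mean(list_dict_var)
    mean_v_sq = dict_mean([dict_pow(d) for d in list_dict_var])
    return {k: mean_v_sq[k] - mean_v[k] ** 2 for k in mean_v}


v_dom0 = {"layer1": np.array([1.0, 2.0])}
v_dom1 = {"layer1": np.array([3.0, 2.0])}
print(variance_between_domains([v_dom0, v_dom1])["layer1"])  # [1. 0.]
```

Computing the variance as mean(v_i^2) - v^2 is cheap. It can, however, return tiny negative values from round-off when the domains are nearly identical.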

domainlab.algos.trainers.train_hyper_scheduler module

update hyper-parameters during training

class domainlab.algos.trainers.train_hyper_scheduler.TrainerHyperScheduler(successor_node=None, extend=None)[source]

Bases: TrainerBasic

before_batch(epoch, ind_batch)[source]

if hyper-parameters should be updated per batch, then step should be set to epoch*self.num_batches + ind_batch


check the performance of randomly initialized weights

set_scheduler(scheduler, total_steps, flag_update_epoch=False, flag_update_batch=False)[source]

set the warmup strategy from the scheduler object, and set whether the hyper-parameter scheduling happens per epoch or per batch

  • scheduler – the class name of the scheduler; the object corresponding to this class name will be created inside the model

  • total_steps – number of steps to change the hyper-parameters

  • flag_update_epoch – if hyper-parameters should be changed per epoch

  • flag_update_batch – if hyper-parameters should be changed per batch
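The step index passed to the warmup can be sketched as follows, using the per-batch formula given for `before_batch`. `scheduler_step` is a hypothetical helper. Raising an error when neither flag is set is also an assumption.

```python
def scheduler_step(epoch: int, ind_batch: int, num_batches: int,
                   flag_update_epoch: bool, flag_update_batch: bool) -> int:
    """Hypothetical: the step index that would be fed to the warmup function."""
    if flag_update_batch:
        return epoch * num_batches + ind_batch  # one step per mini-batch
    if flag_update_epoch:
        return epoch  # one step per epoch
    raise ValueError("enable per-epoch or per-batch scheduling")


# 100 batches per epoch: batch 50 of epoch 3 is step 350 under per-batch scheduling
print(scheduler_step(3, 50, 100, flag_update_epoch=False, flag_update_batch=True))  # 350
```

With per-batch updates, total_steps should therefore be counted in mini-batches rather than in epochs.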


update hyper-parameters only per epoch

domainlab.algos.trainers.train_irm module

trainer for IRMv1 (Invariant Risk Minimization)

class domainlab.algos.trainers.train_irm.TrainerIRM(successor_node=None, extend=None)[source]

Bases: TrainerBasic

IRMv1 splits a minibatch in half and uses an unbiased estimate of the squared gradient norm: the inner product of $$\nabla_{w|w=1}\,\ell(w \cdot \Phi(X^{e,i}), Y^{e,i})$$ with $$\nabla_{w|w=1}\,\ell(w \cdot \Phi(X^{e,j}), Y^{e,j})$$, both of dimension dim(Grad). For more details, see Section 3.2 and Appendix D of Arjovsky et al., “Invariant Risk Minimization.”
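Here is a NumPy sketch of the IRMv1 penalty for softmax classification. It assumes a scalar dummy multiplier w, so dim(Grad) = 1 and the inner product reduces to a product. For cross-entropy, d/dw CE(softmax(w z), y) at w = 1 equals sum_k (p_k - onehot_k) z_k. Because the two halves are independent, the expectation of the product is the squared expected gradient. The function names are hypothetical.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def grad_wrt_dummy_scale(logits, labels):
    """d/dw of mean cross-entropy(softmax(w * logits), labels), evaluated at w = 1."""
    probs = softmax(logits)
    onehot = np.eye(logits.shape[1])[labels]
    return float(((probs - onehot) * logits).sum(axis=1).mean())


def irm_penalty(logits, labels):
    """Unbiased estimate of the squared gradient norm: multiply the gradients
    computed on the two halves of the minibatch."""
    half = len(labels) // 2
    grad_a = grad_wrt_dummy_scale(logits[:half], labels[:half])
    grad_b = grad_wrt_dummy_scale(logits[half:2 * half], labels[half:2 * half])
    return grad_a * grad_b


rng = np.random.default_rng(0)
logits = rng.normal(size=(16, 3))  # stand-in for Phi(X) mapped to class scores
labels = rng.integers(0, 3, size=16)
penalty = irm_penalty(logits, labels)
```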



domainlab.algos.trainers.train_matchdg module

trainer matchdg

class domainlab.algos.trainers.train_matchdg.TrainerMatchDG(successor_node=None, extend=None)[source]

Bases: AbstractTrainer

Contrastive Learning


override abstract method

dset_decoration_args_algo(args, ddset)[source]

decorate the dataset to get extra entries in the loaded item; for instance, JiGen needs the permutation index. This parent-class function delegates the decoration to its decoratee

init_business(model, task, observer, device, aconf, flag_accept=True, flag_erm=False)[source]

initialize member objects


initialize or update match tensor

tr_batch(epoch, batch_idx, x_e, y_e, d_e, others=None)[source]

update network for each batch


data in one batch comes from two sources: one part from the loader, the other part from the match tensor


original dimension is (ref_domain, domain, (channel, img_h, img_w)); use a function so that it is easier to accommodate other data modes (not images)
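The reshaping and the two-source batch can be sketched with NumPy as below. The sizes, and concatenation along the batch dimension, are assumptions for illustration.

```python
import numpy as np


def flatten_match_tensor(match_tensor):
    """Collapse (ref_domain, domain, *data_shape) into (ref_domain * domain, *data_shape).
    Keeping data_shape generic lets the same function handle non-image data,
    e.g. data_shape = (num_features,)."""
    return match_tensor.reshape(-1, *match_tensor.shape[2:])


# hypothetical sizes: 4 reference samples matched across 3 domains, 3x8x8 images
match_tensor = np.zeros((4, 3, 3, 8, 8))
loader_batch = np.ones((16, 3, 8, 8))
# one part of the batch from the loader, the other part from the match tensor
batch = np.concatenate([loader_batch, flatten_match_tensor(match_tensor)], axis=0)
print(batch.shape)  # (28, 3, 8, 8)
```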

domainlab.algos.trainers.train_miro module

author: Kakao Brain; [aut] xudong, alexej

class domainlab.algos.trainers.train_miro.TrainerMiro(successor_node=None, extend=None)[source]

Bases: TrainerBasic

Mutual-Information Regularization with Oracle


check the performance of randomly initialized weights

domainlab.algos.trainers.train_miro_model_wraper module

class domainlab.algos.trainers.train_miro_model_wraper.TrainerMiroModelWraper[source]

Bases: object

Mutual-Information Regularization with Oracle

accept(guest_model, name_feat_layers2extract=None)[source]
cal_feat_layers_ref_model(tensor_x, tensor_y, tensor_d, others=None)[source]
extract_intermediate_features(tensor_x, tensor_y, tensor_d, others=None)[source]

extract features for each layer of the neural network

hook(module, input, output)[source]
hook_ref(module, input, output)[source]

domainlab.algos.trainers.train_miro_utils module

Laplace approximation for Mutual Information estimation

class domainlab.algos.trainers.train_miro_utils.MeanEncoder(inter_layer_feat_shape)[source]

Bases: Module

Identity function


Defines the computation performed at every call.

Should be overridden by all subclasses.


Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class domainlab.algos.trainers.train_miro_utils.VarianceEncoder(inter_layer_feat_shape, init=0.1, eps=1e-05)[source]

Bases: Module

Bias-only model with diagonal covariance


train batch (population) level variance
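Together, MeanEncoder (identity) and VarianceEncoder parameterize a Gaussian. The negative log-likelihood of the frozen oracle features under this Gaussian serves as the regularizer. The sketch below follows the MIRO formulation as implemented in DomainBed, where var = softplus(b) + eps and b is initialized so that the initial variance equals `init`. The NumPy port and the names are hypothetical.

```python
import numpy as np


def softplus(x):
    return np.log1p(np.exp(x))


class VarianceEncoderSketch:
    """Bias-only diagonal variance: var = softplus(b) + eps, where b is set so
    that the initial variance equals init."""

    def __init__(self, shape, init=0.1, eps=1e-5):
        self.eps = eps
        self.b = np.full(shape, np.log(np.exp(init - eps) - 1.0))  # inverse softplus

    def __call__(self, feat):
        # the variance does not depend on feat; broadcast it to the batch shape
        return np.broadcast_to(softplus(self.b) + self.eps, feat.shape)


def miro_reg_loss(mean, feat_ref, var):
    """Gaussian negative log-likelihood of the reference features, up to
    constants: mean over entries of ((mean - ref)^2 / var + log var) / 2."""
    return float(((mean - feat_ref) ** 2 / var + np.log(var)).mean() / 2.0)


rng = np.random.default_rng(0)
feat = rng.normal(size=(8, 16))                   # features of the model being trained
feat_ref = feat + 0.1 * rng.normal(size=(8, 16))  # features of the frozen oracle model
var_enc = VarianceEncoderSketch(shape=(16,))
loss = miro_reg_loss(feat, feat_ref, var_enc(feat))  # MeanEncoder is the identity
```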

domainlab.algos.trainers.train_mldg module

Meta Learning Domain Generalization

class domainlab.algos.trainers.train_mldg.TrainerMLDG(successor_node=None, extend=None)[source]

Bases: AbstractTrainer

basic trainer


check the performance of randomly initialized weights


create virtual source and target domain
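Here is a sketch of the virtual split. It assumes one randomly chosen training domain serves as the virtual target per call. How often `TrainerMLDG` resamples the split, and how many virtual targets it uses, is not stated here. The names are illustrative.

```python
import random


def split_virtual_domains(list_domains, num_virtual_target=1, seed=None):
    """Randomly partition training domains into virtual source (meta-train)
    and virtual target (meta-test) domains."""
    rng = random.Random(seed)
    shuffled = list(list_domains)
    rng.shuffle(shuffled)
    return shuffled[num_virtual_target:], shuffled[:num_virtual_target]


source, target = split_virtual_domains(["art", "cartoon", "photo"], seed=0)
```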



domainlab.algos.trainers.zoo_trainer module

select trainer

class domainlab.algos.trainers.zoo_trainer.TrainerChainNodeGetter(str_trainer)[source]

Bases: object

Chain of Responsibility: node is named in pattern Trainer[XXX] where the string after ‘Trainer’ is the name to be passed to args.trainer.
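The lookup can be sketched as follows. `ChainNode`, `matches` and `handle` are stand-ins defined for illustration, as are the local classes `TrainerBasic`, `TrainerDial` and `TrainerMldg`; none of them are DomainLab's classes. Lowercasing the suffix is an assumption.

```python
class ChainNode:
    """Hypothetical handler in a chain of responsibility."""

    def __init__(self, successor=None):
        self.successor = successor

    def matches(self, request: str) -> bool:
        # class names follow Trainer[XXX]; XXX (lowercased) is the request string
        return request == type(self).__name__[len("Trainer"):].lower()

    def handle(self, request: str) -> "ChainNode":
        if self.matches(request):
            return self
        if self.successor is None:
            raise NotImplementedError(f"no trainer named {request!r}")
        return self.successor.handle(request)


class TrainerBasic(ChainNode):
    pass


class TrainerDial(ChainNode):
    pass


class TrainerMldg(ChainNode):
    pass


chain = TrainerMldg(TrainerDial(TrainerBasic()))
print(type(chain.handle("dial")).__name__)  # TrainerDial
```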

Module contents