domainlab.algos.trainers package

Subpackages

Submodules

domainlab.algos.trainers.a_trainer module

Base Class for trainer

class domainlab.algos.trainers.a_trainer.AbstractTrainer(successor_node=None, extend=None)[source]

Bases: AbstractChainNodeHandler

Algorithm director that controls the data flow

after_batch(epoch, ind_batch)[source]
Parameters:
  • epoch

  • ind_batch

before_batch(epoch, ind_batch)[source]
Parameters:
  • epoch

  • ind_batch

before_tr()[source]

before training, probe model performance

cal_reg_loss(tensor_x, tensor_y, tensor_d, others=None)[source]

decorate trainer regularization loss: combine the losses of the current trainer with self._model.cal_reg_loss, where self._model can be either a trainer or a model

cal_reg_loss_over_task_loss_ratio()[source]

estimate the scale of each loss term and match each term to the major loss via a ratio; this ratio will then be multiplied with the multiplier

property decoratee
dset_decoration_args_algo(args, ddset)[source]

decorate the dataset to get extra entries when loading an item; for instance, JiGen needs a permutation index. This parent-class function delegates decoration to its decoratee

extend(trainer)[source]

extend current trainer with another trainer

get_model()[source]

recursively get the “real” model from trainer
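
The decorator chain implied by extend() and get_model() can be illustrated with a minimal standalone sketch; the class below is a simplified stand-in, not the actual domainlab implementation:

    class Trainer:
        """simplified stand-in for the decorator chain behind extend()/get_model()"""

        def __init__(self, name, decoratee=None):
            self.name = name
            self._decoratee = decoratee  # can hold either a model or another trainer

        def extend(self, trainer):
            # wrap another trainer; their regularization losses get combined
            self._decoratee = trainer

        def get_model(self):
            # recursively unwrap until the innermost object (the "real" model) is reached
            if isinstance(self._decoratee, Trainer):
                return self._decoratee.get_model()
            return self._decoratee

    class DummyModel:
        pass

    outer = Trainer("ma")                      # e.g. moving-average trainer as outermost
    inner = Trainer("basic", DummyModel())
    outer.extend(inner)
    print(type(outer.get_model()).__name__)    # -> DummyModel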

init_business(model, task, observer, device, aconf, flag_accept=True)[source]

initialize the trainer with model, task, observer, device, aconf

is_myjob(request)[source]
Parameters:

request – string

property list_tr_domain_size

get a list of training domain size

property model

property model, which can be another trainer or model

property name

get the name of the algorithm

property p_na_prefix

common prefix for Trainers

post_tr()[source]

after training

print_parameters()[source]

Function to print all parameters of the object. Can be used to print the parameters of any child class

reset()[source]

make a new optimizer to clear internal state

property str_metric4msel

metric for model selection

abstract tr_epoch(epoch)[source]
Parameters:

epoch

domainlab.algos.trainers.a_trainer.mk_opt(model, aconf)[source]

create optimizer
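
A minimal sketch of such an optimizer factory, assuming aconf exposes a learning-rate field named lr; both the field name and the choice of Adam are assumptions, not the documented behaviour:

    import torch

    def mk_opt_sketch(model, aconf):
        """illustrative optimizer factory; `aconf.lr` and the use of Adam are assumptions"""
        return torch.optim.Adam(model.parameters(), lr=aconf.lr)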

domainlab.algos.trainers.args_dial module

domain invariant adversarial trainer hyper-parameters

domainlab.algos.trainers.args_dial.add_args2parser_dial(parser)[source]

append hyper-parameters to the main argparser

domainlab.algos.trainers.args_miro module

miro trainer configurations

domainlab.algos.trainers.args_miro.add_args2parser_miro(parser)[source]

append hyper-parameters to the main argparser

domainlab.algos.trainers.hyper_scheduler module

update hyper-parameters during training

class domainlab.algos.trainers.hyper_scheduler.HyperSchedulerWarmupExponential(trainer, **kwargs)[source]

Bases: HyperSchedulerWarmupLinear

HyperScheduler Exponential

warmup(par_setpoint, epoch)[source]

start from a small value of the parameter and ramp up exponentially to the steady-state value over total_steps

Parameters:

epoch

class domainlab.algos.trainers.hyper_scheduler.HyperSchedulerWarmupLinear(trainer, **kwargs)[source]

Bases: object

set_steps(total_steps)[source]

set number of total_steps to gradually change optimization parameter

warmup(par_setpoint, epoch)[source]

warmup: start from a small value of the parameter and ramp up linearly to the steady-state value over total_steps

Parameters:

epoch
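
The two warmup schedules can be sketched with plausible stand-in formulas: a linear ramp and an exponentially saturating ramp toward the setpoint (the exact domainlab formulas may differ):

    import math

    def warmup_linear(par_setpoint, epoch, total_steps):
        """ramp linearly from 0 to par_setpoint over total_steps (assumed formula)"""
        ratio = min(epoch / total_steps, 1.0)
        return par_setpoint * ratio

    def warmup_exponential(par_setpoint, epoch, total_steps, slope=5.0):
        """saturating exponential ramp toward par_setpoint (assumed formula)"""
        ratio = min(epoch / total_steps, 1.0)
        return par_setpoint * (1.0 - math.exp(-slope * ratio))

    for ep in (0, 5, 10, 20):
        print(ep, warmup_linear(1.0, ep, 10), round(warmup_exponential(1.0, ep, 10), 3))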

domainlab.algos.trainers.mmd_base module

Alexej, Xudong

class domainlab.algos.trainers.mmd_base.TrainerMMDBase(successor_node=None, extend=None)[source]

Bases: TrainerBasic

causal matching

gaussian_kernel(x, y)[source]

kernel for MMD

mmd(x, y)[source]

maximum mean discrepancy

my_cdist(x1, x2)[source]

distance for Gaussian
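
A self-contained sketch of the three helpers; the bandwidth list in the Gaussian kernel is an assumption, and the real implementation may use different values:

    import torch

    def my_cdist_sketch(x1, x2):
        """squared Euclidean pairwise distances between the rows of x1 and x2"""
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2)
        return res.add_(x1_norm).clamp_min_(1e-30)

    def gaussian_kernel_sketch(x, y, gammas=(0.001, 0.01, 0.1, 1, 10, 100)):
        """mixture of RBF kernels over several bandwidths (bandwidths are illustrative)"""
        dist = my_cdist_sketch(x, y)
        kernel = torch.zeros_like(dist)
        for gamma in gammas:
            kernel.add_(torch.exp(dist.mul(-gamma)))
        return kernel

    def mmd_sketch(x, y):
        """biased estimate of the maximum mean discrepancy between samples x and y"""
        k_xx = gaussian_kernel_sketch(x, x).mean()
        k_yy = gaussian_kernel_sketch(y, y).mean()
        k_xy = gaussian_kernel_sketch(x, y).mean()
        return k_xx + k_yy - 2 * k_xy

    x, y = torch.randn(32, 128), torch.randn(32, 128)
    print(mmd_sketch(x, y).item())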

domainlab.algos.trainers.train_basic module

basic trainer

class domainlab.algos.trainers.train_basic.TrainerBasic(successor_node=None, extend=None)[source]

Bases: AbstractTrainer

basic trainer

after_epoch(epoch)[source]

observer collect information

before_epoch()[source]

set the model to train mode and initialize some member variables

before_tr()[source]

check the performance of randomly initialized weight

cal_loss(tensor_x, tensor_y, tensor_d, others)[source]

so that the user API can use trainer.cal_loss to train

log_loss(list_b_reg_loss, loss_task, loss)[source]

just for logging the self.epo_reg_loss_tr

tr_batch(tensor_x, tensor_y, tensor_d, others, ind_batch, epoch)[source]

optimize the neural network for one step on a mini-batch of data

tr_epoch(epoch)[source]
Parameters:

epoch

domainlab.algos.trainers.train_basic.list_divide(list_val, scalar)[source]

domainlab.algos.trainers.train_causIRL module

Alex, Xudong

class domainlab.algos.trainers.train_causIRL.TrainerCausalIRL(successor_node=None, extend=None)[source]

Bases: TrainerBasic

causal matching

gaussian_kernel(x, y)[source]

kernel for MMD

mmd(x, y)[source]

maximum mean discrepancy

my_cdist(x1, x2)[source]

distance for Gaussian

tr_batch(tensor_x, tensor_y, tensor_d, others, ind_batch, epoch)[source]

optimize the neural network for one step on a mini-batch of data

domainlab.algos.trainers.train_coral module

Deep CORAL: Correlation Alignment for Deep Domain Adaptation [au] Alexej, Xudong

class domainlab.algos.trainers.train_coral.TrainerCoral(successor_node=None, extend=None)[source]

Bases: TrainerMMDBase

cross domain MMD

cross_domain_mmd(tuple_data_domains_batch)[source]

domain-pairwise mmd

tr_epoch(epoch)[source]
Parameters:

epoch

domainlab.algos.trainers.train_dial module

use random start to generate adversarial images

class domainlab.algos.trainers.train_dial.TrainerDIAL(successor_node=None, extend=None)[source]

Bases: TrainerBasic

Trainer Domain Invariant Adversarial Learning

gen_adversarial(device, img_natural, vec_y)[source]

use naive trimming (clipping) to find an optimized image along the direction of the adversarial gradient; this is not necessarily constraint-optimal due to nonlinearity, as the epsilon constraint is only enforced ad hoc
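
A standalone PGD-style sketch of such an adversarial generator with a random start; epsilon, the step size, the number of steps, and the label handling are illustrative assumptions:

    import torch
    import torch.nn.functional as F

    def gen_adversarial_sketch(model, img_natural, vec_y, epsilon=0.03,
                               step_size=0.01, num_steps=10):
        """random start inside the epsilon ball, then signed-gradient ascent with trimming"""
        # accept either one-hot labels or class indices
        target = vec_y.argmax(dim=1) if vec_y.dim() > 1 else vec_y
        img_adv = img_natural + epsilon * torch.randn_like(img_natural)
        for _ in range(num_steps):
            img_adv = img_adv.detach().requires_grad_(True)
            loss = F.cross_entropy(model(img_adv), target)
            grad = torch.autograd.grad(loss, img_adv)[0]
            # ascend the loss, then trim back into the epsilon ball around the natural image
            img_adv = img_adv.detach() + step_size * grad.sign()
            img_adv = torch.min(torch.max(img_adv, img_natural - epsilon),
                                img_natural + epsilon).clamp(0.0, 1.0)
        return img_adv

    net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    x_adv = gen_adversarial_sketch(net, torch.rand(4, 3, 32, 32), torch.randint(0, 10, (4,)))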

domainlab.algos.trainers.train_ema module

simple exponential moving average of each layer's weights after each epoch; when composing trainers, e.g. trainer = ma_trainer2_trainer3, always set the moving-average trainer to be the outermost one

Paper: Ensemble of Averages: Improving Model Selection and Boosting Performance in Domain Generalization. Devansh Arpit, Huan Wang, Yingbo Zhou, Caiming Xiong, Salesforce Research, USA

class domainlab.algos.trainers.train_ema.TrainerMA(successor_node=None, extend=None)[source]

Bases: TrainerBasic

the initialization of this class is handled in one block/section of the abstract class initializer; doing otherwise would break the class inheritance.

after_epoch(epoch)[source]

observer collect information

move_average(dict_data, epoch)[source]

for each epoch, convexly combine the weights of each layer. Paper: Ensemble of Averages: Improving Model Selection and Boosting Performance in Domain Generalization. Devansh Arpit, Huan Wang, Yingbo Zhou, Caiming Xiong, Salesforce Research, USA
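
A minimal sketch of the per-epoch convex combination of weights; the decay value and the way the running state is stored are assumptions:

    import copy
    import torch

    def move_average_sketch(dict_ema, dict_current, beta=0.9):
        """convex combination of the previous running average and the current weights"""
        return {key: beta * dict_ema[key] + (1.0 - beta) * dict_current[key]
                for key in dict_current}

    net = torch.nn.Linear(4, 2)
    ema_state = copy.deepcopy(net.state_dict())
    # ... train for an epoch, then fold the new weights into the running average
    ema_state = move_average_sketch(ema_state, net.state_dict())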

domainlab.algos.trainers.train_fishr module

Fishr trainer: align the domain-level variances of the gradients

class domainlab.algos.trainers.train_fishr.TrainerFishr(successor_node=None, extend=None)[source]

Bases: TrainerBasic

The goal is to minimize the variance of the domain-level variance of the gradients. This aligns the domain-level loss landscapes locally around the final weights, reducing inconsistencies across domains.

For more details, see: Alexandre Ramé, Corentin Dancette, and Matthieu Cord.

“Fishr: Invariant gradient variances for out-of-distribution generalization.” International Conference on Machine Learning. PMLR, 2022.

cal_dict_variance_grads(tensor_x, vec_y)[source]

Calculates the domain-level variances of the gradients w.r.t. the scalar component of the weight tensor for the layer in question, i.e. $$v_i = \mathrm{var}(\nabla_{\theta}\ell(x^{(d_i)}, y^{(d_i)}))$$, where $$d_i$$ means data coming from domain i. The computation is done using the package backpack.

Input: tensor_x, a tensor, where the first dimension is the batch size and vec_y, which is a vector representing the output labels.

Return: dictionary, where the key is the name for the layer of a neural network and the value is the diagonal variance of each scalar component of the gradient of the loss w.r.t. the parameter.

Return Example: {“layer1”: Tensor[batchsize=32, 64, 3, 11, 11 ]} as a convolution kernel

cal_mean_across_dict(list_dict)[source]

Calculates the mean across several dictionaries. Input: list of dictionaries, where the values of each dictionary are tensors. Return: dictionary, where the values are tensors. The scalar values of the tensors contain the mean across the first dimension of the dictionaries from the list of inputs.

cal_power_single_dict(mdict)[source]

Calculates the element-wise power of the values in a dictionary, where the values are tensors. Input: dictionary, where the values are tensors. Return: dictionary, where the values are tensors. The scalar values of the tensors are the element-wise power of the scalars in the input dictionary.

tr_epoch(epoch)[source]
Parameters:

epoch

var_grads_and_loss(tuple_data_domains_batch)[source]

Calculate the domain-level variance of the gradients and the layer-wise ERM loss. Input: a tuple containing lists with the data per domain. Return: two lists. The first one contains dictionaries with the gradient variances; the keys are the layers and the values are tensors in which the gradient variances are stored. The second list contains the losses; each list entry represents the summed-up ERM loss of a single layer.

variance_between_dict(list_dict_var_paragrad)[source]

Computes the variance of the domain-level gradient variances, layer-wise. Let $\bar{v} = \frac{1}{n}\sum_{i=1}^{n} v_i$ represent the mean across the $n$ domains, with $$v_i = \mathrm{var}(\nabla_{\theta}\ell(x^{(d_i)}, y^{(d_i)}))$$, where $$d_i$$ means data coming from domain i. We are interested in $\frac{1}{n}\sum_i (v_i - \bar{v})^2 = \frac{1}{n}\sum_i v_i^2 - \bar{v}^2$.

Input: list of dictionaries, each dictionary has the structure {“layer1”: tensor[64, 3, 11, 11], “layer2”: tensor[8, 3, 5, 5]}….. The scalar values in the dictionary are the variances of the gradient of the loss w.r.t. the scalar component of the weight tensor for the layer in question, where the variance is computed w.r.t. the minibatch of a particular domain.

Return: dictionary, containing the layers as keys and tensors as values. The variances are stored in the tensors as scalars.
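
The identity $\frac{1}{n}\sum_i (v_i - \bar{v})^2 = \frac{1}{n}\sum_i v_i^2 - \bar{v}^2$ used above can be checked with a small standalone sketch over per-domain dictionaries; the layer names and tensor shapes are illustrative:

    import torch

    def variance_between_dict_sketch(list_dict_var_paragrad):
        """layer-wise variance across domains: E[v_i^2] - (E[v_i])^2"""
        keys = list_dict_var_paragrad[0].keys()
        n_domains = len(list_dict_var_paragrad)
        mean = {k: sum(d[k] for d in list_dict_var_paragrad) / n_domains for k in keys}
        mean_sq = {k: sum(d[k] ** 2 for d in list_dict_var_paragrad) / n_domains for k in keys}
        return {k: mean_sq[k] - mean[k] ** 2 for k in keys}

    per_domain = [{"layer1": torch.rand(8, 3, 5, 5)} for _ in range(3)]
    print(variance_between_dict_sketch(per_domain)["layer1"].shape)  # torch.Size([8, 3, 5, 5])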

domainlab.algos.trainers.train_hyper_scheduler module

update hyper-parameters during training

class domainlab.algos.trainers.train_hyper_scheduler.TrainerHyperScheduler(successor_node=None, extend=None)[source]

Bases: TrainerBasic

before_batch(epoch, ind_batch)[source]

if hyper-parameters should be updated per batch, then step should be set to epoch*self.num_batches + ind_batch

before_tr()[source]

check the performance of randomly initialized weight

set_scheduler(scheduler, total_steps, flag_update_epoch=False, flag_update_batch=False)[source]

set the warmup strategy from the scheduler object; set whether the hyper-parameter scheduling happens per epoch or per batch

Parameters:
  • scheduler – The class name of the scheduler; an object of this class will be created inside the model

  • total_steps – number of steps to change the hyper-parameters

  • flag_update_epoch – if hyper-parameters should be changed per epoch

  • flag_update_batch – if hyper-parameters should be changed per batch

tr_epoch(epoch)[source]

update hyper-parameters only per epoch

domainlab.algos.trainers.train_irm module

Invariant Risk Minimization (IRM) trainer

class domainlab.algos.trainers.train_irm.TrainerIRM(successor_node=None, extend=None)[source]

Bases: TrainerBasic

IRMv1 splits a minibatch into two halves and uses an unbiased estimate of the squared gradient norm via the inner product of $$\nabla_{w|w=1} \ell(w \cdot \Phi(X^{e,i}), Y^{e,i})$$ of dimension dim(Grad) with $$\nabla_{w|w=1} \ell(w \cdot \Phi(X^{e,j}), Y^{e,j})$$ of dimension dim(Grad). For more details, see Section 3.2 and Appendix D of: Arjovsky et al., "Invariant Risk Minimization."
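
A standalone sketch of the IRMv1 penalty described above, with the dummy classifier w fixed at 1.0 and the minibatch split into two halves; batching, device handling, and the penalty weighting in domainlab may differ:

    import torch
    import torch.nn.functional as F

    def irm_penalty_sketch(logits, target):
        """unbiased estimate of the squared gradient norm via a split-batch inner product"""
        scale = torch.ones(1, requires_grad=True)          # dummy classifier w = 1.0
        loss_half1 = F.cross_entropy(logits[::2] * scale, target[::2])
        loss_half2 = F.cross_entropy(logits[1::2] * scale, target[1::2])
        grad_half1 = torch.autograd.grad(loss_half1, [scale], create_graph=True)[0]
        grad_half2 = torch.autograd.grad(loss_half2, [scale], create_graph=True)[0]
        return torch.sum(grad_half1 * grad_half2)

    logits = torch.randn(16, 7, requires_grad=True)
    target = torch.randint(0, 7, (16,))
    print(irm_penalty_sketch(logits, target).item())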

tr_epoch(epoch)[source]
Parameters:

epoch

domainlab.algos.trainers.train_matchdg module

trainer matchdg

class domainlab.algos.trainers.train_matchdg.TrainerMatchDG(successor_node=None, extend=None)[source]

Bases: AbstractTrainer

Contrastive Learning

before_tr()[source]

override abstract method

dset_decoration_args_algo(args, ddset)[source]

decorate dataset to get extra entries in load item, for instance, jigen need permutation index this parent class function delegate decoration to its decoratee

init_business(model, task, observer, device, aconf, flag_accept=True, flag_erm=False)[source]

initialize member objects

mk_match_tensor(epoch)[source]

initialize or update match tensor

tr_batch(epoch, batch_idx, x_e, y_e, d_e, others=None)[source]

update network for each batch

tr_epoch(epoch)[source]

data in one batch comes from two sources: one part from the loader, the other part from the match tensor

domainlab.algos.trainers.train_matchdg.match_tensor_reshape(batch_tensor_ref_domain2each)[source]

original dimension is (ref_domain, domain, (channel, img_h, img_w)); use a function so it is easier to accommodate other data modes (not image)

domainlab.algos.trainers.train_miro module

author: Kakao Brain. https://arxiv.org/pdf/2203.10789#page=3.77 [aut] xudong, alexej

class domainlab.algos.trainers.train_miro.TrainerMiro(successor_node=None, extend=None)[source]

Bases: TrainerBasic

Mutual-Information Regularization with Oracle

before_tr()[source]

check the performance of randomly initialized weight

domainlab.algos.trainers.train_miro_model_wraper module

https://arxiv.org/pdf/2203.10789#page=3.77

class domainlab.algos.trainers.train_miro_model_wraper.TrainerMiroModelWraper[source]

Bases: object

Mutual-Information Regularization with Oracle

accept(guest_model, name_feat_layers2extract=None)[source]
cal_feat_layers_ref_model(tensor_x, tensor_y, tensor_d, others=None)[source]
clear_features()[source]
extract_intermediate_features(tensor_x, tensor_y, tensor_d, others=None)[source]

extract features for each layer of the neural network

get_shapes(input_shape)[source]
hook(module, input, output)[source]
hook_ref(module, input, output)[source]
register_feature_storage_hook(feat_layers=None)[source]

domainlab.algos.trainers.train_miro_utils module

Laplace approximation for Mutual Information estimation

class domainlab.algos.trainers.train_miro_utils.MeanEncoder(inter_layer_feat_shape)[source]

Bases: Module

Identity function

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class domainlab.algos.trainers.train_miro_utils.VarianceEncoder(inter_layer_feat_shape, init=0.1, eps=1e-05)[source]

Bases: Module

Bias-only model with diagonal covariance

forward(feat_layer_tensor_batch)[source]

train the batch-level (population) variance
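
A standalone sketch of a bias-only diagonal-variance encoder with a softplus parameterization; the constructor arguments mirror the documented ones, but the exact formula is an assumption:

    import torch
    from torch import nn
    import torch.nn.functional as F

    class VarianceEncoderSketch(nn.Module):
        """one learnable variance per feature position, softplus-parameterized (assumed formula)"""

        def __init__(self, inter_layer_feat_shape, init=0.1, eps=1e-05):
            super().__init__()
            self.eps = eps
            # invert the softplus so that the initial variance equals `init`
            raw_init = torch.log(torch.exp(torch.tensor(init - eps)) - 1.0)
            self.raw_var = nn.Parameter(torch.full(inter_layer_feat_shape, raw_init.item()))

        def forward(self, feat_layer_tensor_batch):
            # the variance does not depend on the input batch: bias-only model
            return F.softplus(self.raw_var) + self.eps

    enc = VarianceEncoderSketch((64, 8, 8))
    print(enc(torch.randn(32, 64, 8, 8)).shape)  # torch.Size([64, 8, 8])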

domainlab.algos.trainers.train_mldg module

Meta Learning Domain Generalization

class domainlab.algos.trainers.train_mldg.TrainerMLDG(successor_node=None, extend=None)[source]

Bases: AbstractTrainer

basic trainer

before_tr()[source]

check the performance of randomly initialized weight

prepare_ziped_loader()[source]

create virtual source and target domain

tr_epoch(epoch)[source]
Parameters:

epoch
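
A first-order sketch of one MLDG meta step in the spirit of Li et al., 2018: adapt a clone of the network on a virtual-source batch, evaluate it on a virtual-target batch, and accumulate both gradients onto the original network. Hyper-parameter names and values are illustrative and not the domainlab defaults:

    import copy
    import torch
    import torch.nn.functional as F

    def mldg_step_sketch(network, optimizer, batch_meta_train, batch_meta_test,
                         lr_inner=0.01, beta=1.0):
        """one first-order MLDG update; batches are (x, y) pairs from virtual source/target"""
        optimizer.zero_grad()
        for param in network.parameters():
            param.grad = torch.zeros_like(param)

        # inner update on a clone, driven by the virtual-source (meta-train) loss
        inner_net = copy.deepcopy(network)
        inner_opt = torch.optim.SGD(inner_net.parameters(), lr=lr_inner)
        x_tr, y_tr = batch_meta_train
        loss_inner = F.cross_entropy(inner_net(x_tr), y_tr)
        inner_opt.zero_grad()
        loss_inner.backward()
        inner_opt.step()

        # copy the meta-train gradient onto the original network
        for p_net, p_inner in zip(network.parameters(), inner_net.parameters()):
            if p_inner.grad is not None:
                p_net.grad.add_(p_inner.grad)

        # meta-test loss of the adapted clone; first-order gradient added as well
        x_te, y_te = batch_meta_test
        loss_outer = F.cross_entropy(inner_net(x_te), y_te)
        grads_outer = torch.autograd.grad(loss_outer, list(inner_net.parameters()),
                                          allow_unused=True)
        for p_net, g_outer in zip(network.parameters(), grads_outer):
            if g_outer is not None:
                p_net.grad.add_(beta * g_outer)

        optimizer.step()
        return loss_inner.item() + beta * loss_outer.item()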

domainlab.algos.trainers.zoo_trainer module

select trainer

class domainlab.algos.trainers.zoo_trainer.TrainerChainNodeGetter(str_trainer)[source]

Bases: object

Chain of Responsibility: node is named in pattern Trainer[XXX] where the string after ‘Trainer’ is the name to be passed to args.trainer.
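
The lookup pattern can be illustrated with a minimal standalone sketch of a chain walk driven by is_myjob; the classes and the walk function below are simplified stand-ins, not the actual domainlab node handler:

    class TrainerDummyA:
        def is_myjob(self, request):
            # the part after the 'Trainer' prefix, lower-cased, is matched against args.trainer
            return request == "dummya"

    class TrainerDummyB:
        def is_myjob(self, request):
            return request == "dummyb"

    def walk_chain(request, chain):
        """return the first node in the chain that accepts the request"""
        for node in chain:
            if node.is_myjob(request):
                return node
        raise NotImplementedError(f"no trainer found for {request!r}")

    print(type(walk_chain("dummyb", [TrainerDummyA(), TrainerDummyB()])).__name__)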

Module contents