Source code for domainlab.algos.trainers.train_ema
"""
simple exponential moving average of each layers, after each epoch,
trainer=ma_trainer2_trainer3
always set ma to be outer most
Paper:
Ensemble of Averages: Improving Model Selection and
Boosting Performance in Domain Generalization
Devansh Arpit, Huan Wang, Yingbo Zhou, Caiming Xiong
Salesforce Research, USA
"""
import torch
from domainlab.algos.trainers.train_basic import TrainerBasic
class TrainerMA(TrainerBasic):
"""
initializer of this class goes to one block/section in the abstract class
initializer, otherwise it will break the class inheritance.
"""
def move_average(self, dict_data, epoch):
"""
for each epoch, convex combine the weights for each layer
Paper:
Ensemble of Averages: Improving Model Selection and
Boosting Performance in Domain Generalization
Devansh Arpit, Huan Wang, Yingbo Zhou, Caiming Xiong
Salesforce Research, USA
"""
        self.ma_weight_previous_model_params = epoch / (epoch + 1)
        # schedule: 1/2, 2/3, 3/4, 4/5, ...
        # the weight on the previous model converges to 1 as training
        # progresses, so later epochs change the average less and less
dict_return_ema_para_curr_iter = {}
for key, data in dict_data.items():
# data = data.view(1, -1) # make it rank 1 tensor (a.k.a. vector)
            if self._ma_iter == 0:
                # first call: no previous average yet, use weights as-is
                local_data_convex = data
else:
previous_data = self._dict_previous_para_persist[key]
local_data_convex = \
self.ma_weight_previous_model_params * previous_data + \
(1 - self.ma_weight_previous_model_params) * data
            # A correction by 1 / (1 - self.ma_weight_previous_model_params)
            # would make the gradient magnitude backpropagated through `data`
            # independent of that weight; we skip it because (1 - rho) tends
            # to zero as epochs go on, so the division would blow up the
            # network weights until they overflow:
            # dict_return_ema_para_curr_iter[key] = \
            #     local_data_convex / (1 - self.ma_weight_previous_model_params)
dict_return_ema_para_curr_iter[key] = local_data_convex
self._dict_previous_para_persist[key] = \
local_data_convex.clone().detach() # used as previous data
self._ma_iter += 1
return dict_return_ema_para_curr_iter
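    # Hypothetical trace of move_average for one layer, assuming epochs
    # are counted from zero and per-epoch snapshots theta_0, theta_1, ...:
    #   epoch 0: first call, average = theta_0
    #   epoch 1: rho = 1/2, average = (theta_0 + theta_1) / 2
    #   epoch 2: rho = 2/3, average = (theta_0 + theta_1 + theta_2) / 3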
def after_epoch(self, epoch):
        torch_model = self.get_model()
        # state_dict() covers parameters and persistent buffers alike
        dict_para = torch_model.state_dict()
        new_dict_para = self.move_average(dict_para, epoch)
        # load_state_dict copies the averaged values into the model,
        # so no deepcopy is needed here
        torch_model.load_state_dict(new_dict_para)
        super().after_epoch(epoch)
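
# A minimal standalone sketch (not part of the trainer) of the same
# running-mean arithmetic on a toy state dict; `toy_snapshots` is a
# hypothetical name used only for illustration:
#
#     import torch
#     toy_snapshots = [{"w": torch.tensor([0.0])},
#                      {"w": torch.tensor([2.0])},
#                      {"w": torch.tensor([4.0])}]
#     ema = {}
#     for t, snap in enumerate(toy_snapshots):
#         for key, val in snap.items():
#             if t == 0:
#                 ema[key] = val.clone()
#             else:
#                 rho = t / (t + 1)
#                 ema[key] = rho * ema[key] + (1 - rho) * val
#     # ema["w"] == tensor([2.0]), the mean of the three snapshots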