Source code for domainlab.algos.trainers.train_hyper_scheduler

"""
update hyper-parameters during training
"""
from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupLinear
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.utils.logger import Logger


class TrainerHyperScheduler(TrainerBasic):
    """
    Trainer that updates the model's hyper-parameters during training.

    The update is delegated to ``self.model.hyper_update`` together with the
    scheduler object created in :meth:`set_scheduler`; it can be triggered
    per epoch, per batch, or both, depending on the flags passed there.
    """

    def set_scheduler(
        self, scheduler, total_steps, flag_update_epoch=False, flag_update_batch=False
    ):
        """
        set the warmup strategy from objective scheduler
        set whether the hyper-parameter scheduling happens per epoch or per batch

        Args:
            scheduler: The class name of the scheduler, the object corresponding
                to this class name will be created inside model
            total_steps: number of steps to change the hyper-parameters
            flag_update_epoch: if hyper-parameters should be changed per epoch
            flag_update_batch: if hyper-parameters should be changed per batch
        """
        # let model register its hyper-parameters to the scheduler
        self.hyper_scheduler = self.model.hyper_init(scheduler)
        self.flag_update_hyper_per_epoch = flag_update_epoch
        self.flag_update_hyper_per_batch = flag_update_batch
        self.hyper_scheduler.set_steps(total_steps=total_steps)

    def before_batch(self, epoch, ind_batch):
        """
        if hyper-parameters should be updated per batch, then step
        should be set to epoch*self.num_batches + ind_batch

        Args:
            epoch: index of the current epoch
            ind_batch: index of the current batch within the epoch

        Returns:
            whatever the parent class's ``before_batch`` returns
        """
        if self.flag_update_hyper_per_batch:
            # global step counter across epochs
            self.model.hyper_update(
                epoch * self.num_batches + ind_batch, self.hyper_scheduler
            )
        return super().before_batch(epoch, ind_batch)

    def before_tr(self):
        """
        make sure a hyper-parameter scheduler exists before delegating to the
        parent class's ``before_tr``

        If no scheduler has been set, a warning is logged and
        ``HyperSchedulerWarmupLinear`` is installed with
        ``total_steps=self.aconf.warmup`` and per-epoch updates enabled.
        """
        if self.hyper_scheduler is None:
            logger = Logger.get_logger()
            # single literal: the former implicit concatenation dropped the
            # space after the comma
            logger.warning(
                "hyper-parameter scheduler not set, "
                "going to use default Warmup and epoch update"
            )
            self.set_scheduler(
                HyperSchedulerWarmupLinear,
                total_steps=self.aconf.warmup,
                flag_update_epoch=True,
            )
        super().before_tr()

    def tr_epoch(self, epoch):
        """
        update hyper-parameters only per epoch

        Args:
            epoch: index of the current epoch, passed to the model as the step

        Returns:
            whatever the parent class's ``tr_epoch`` returns
        """
        if self.flag_update_hyper_per_epoch:
            self.model.hyper_update(epoch, self.hyper_scheduler)
        return super().tr_epoch(epoch)