# Source code for domainlab.algos.trainers.zoo_trainer
"""
select trainer
"""
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.algos.trainers.train_ema import TrainerMA
from domainlab.algos.trainers.train_dial import TrainerDIAL
from domainlab.algos.trainers.train_hyper_scheduler \
import TrainerHyperScheduler
from domainlab.algos.trainers.train_matchdg import TrainerMatchDG
from domainlab.algos.trainers.train_mldg import TrainerMLDG
from domainlab.algos.trainers.train_fishr import TrainerFishr
from domainlab.algos.trainers.train_irm import TrainerIRM
from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL
from domainlab.algos.trainers.train_coral import TrainerCoral
from domainlab.algos.trainers.train_miro import TrainerMiro
class TrainerChainNodeGetter:
    """
    Chain of Responsibility: each node is named in the pattern Trainer[XXX],
    where the string after 'Trainer' is the name to be passed to args.trainer.

    Several trainers can be combined by joining their names with underscores,
    e.g. "xxx_yyy": the first name selects the outermost trainer node, and
    each following name selects a node that the previous one is extended with.
    """

    def __init__(self, str_trainer):
        """__init__.

        :param str_trainer: trainer name(s), e.g. the value of args.trainer;
            underscore-separated names are composed. May be None, in which
            case a default can be supplied when the getter is called.
        """
        self._list_str_trainer = None
        if str_trainer is not None:
            self._list_str_trainer = str_trainer.split("_")
            # the first name is the outermost trainer; the remaining names
            # are consumed one by one in __call__
            self.request = self._list_str_trainer.pop(0)
        else:
            self.request = str_trainer

    @staticmethod
    def _build_chain():
        """Construct a new chain over all known trainers and return its head."""
        chain = TrainerBasic(None)
        chain = TrainerMA(chain)
        chain = TrainerDIAL(chain)
        chain = TrainerMatchDG(chain)
        chain = TrainerMLDG(chain)
        chain = TrainerFishr(chain)
        chain = TrainerIRM(chain)
        chain = TrainerHyperScheduler(chain)
        chain = TrainerCausalIRL(chain)
        chain = TrainerCoral(chain)
        chain = TrainerMiro(chain)
        return chain

    def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
        """
        1. construct the chain, filter out responsible node,
        create heavy-weight business object
        2. hard code seems to be the best solution

        :param lst_candidates: if given, every requested name must be in it
        :param default: name to use when no trainer was requested
            (str_trainer was None)
        :param lst_excludes: if given, no requested name may be in it
        :return: the node returned by the chain for the first requested name,
            extended with the node(s) for any further underscore-separated names
        :raises RuntimeError: if a requested name is not in lst_candidates,
            or is in lst_excludes
        """
        # Apply the default before validating, so an absent request that
        # falls back to a permitted default is not rejected as "None".
        if default is not None and self.request is None:
            self.request = default
        if lst_candidates is not None and self.request not in lst_candidates:
            # implicit concatenation: no stray indentation in the message
            raise RuntimeError(
                f"desired {self.request} is not supported "
                f"among {lst_candidates}"
            )
        if lst_excludes is not None and self.request in lst_excludes:
            raise RuntimeError(
                f"desired {self.request} is not supported among {lst_excludes}"
            )
        # A new chain is built for every name, so nodes are presumably not
        # shared between positions of a composed request.
        node = self._build_chain().handle(self.request)
        if self._list_str_trainer:
            # The recursive call handles the next name and, in turn, every
            # name after it, so the whole remaining list is consumed here.
            self.request = self._list_str_trainer.pop(0)
            node.extend(self.__call__(lst_candidates, default, lst_excludes))
        return node