# Source code for domainlab.algos.trainers.zoo_trainer
"""
select trainer
"""
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.algos.trainers.train_ema import TrainerMA
from domainlab.algos.trainers.train_dial import TrainerDIAL
from domainlab.algos.trainers.train_hyper_scheduler \
    import TrainerHyperScheduler
from domainlab.algos.trainers.train_matchdg import TrainerMatchDG
from domainlab.algos.trainers.train_mldg import TrainerMLDG
from domainlab.algos.trainers.train_fishr import TrainerFishr
from domainlab.algos.trainers.train_irm import TrainerIRM
from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL
from domainlab.algos.trainers.train_coral import TrainerCoral
from domainlab.algos.trainers.train_miro import TrainerMiro
# [docs]
class TrainerChainNodeGetter(object):
    """
    Chain of Responsibility: node is named in pattern Trainer[XXX] where the string
    after 'Trainer' is the name to be passed to args.trainer.

    Several names may be joined with "_" (e.g. "a_b"): the first name selects
    the node returned by the call, and the node selected by each following
    name is attached to its predecessor through ``extend``.
    """
    def __init__(self, str_trainer):
        """__init__.
        :param str_trainer: trainer name, or several names joined by "_".
            None means no trainer was requested explicitly; a default can
            then be supplied when the instance is called.
        """
        self._list_str_trainer = None
        if str_trainer is not None:
            self._list_str_trainer = str_trainer.split("_")
            # the first name is the primary request, the rest stay queued
            self.request = self._list_str_trainer.pop(0)
        else:
            self.request = str_trainer

    def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
        """
        1. construct the chain, filter out responsible node,
        create heavy-weight business object
        2. hard code seems to be the best solution

        Note: the call consumes the queued names and overwrites
        ``self.request``, so the instance is meant to be called once.

        :param lst_candidates: if not None, the only names that are accepted
        :param default: name to use when no trainer was requested (request
            is None)
        :param lst_excludes: if not None, names that are rejected
        :return: the node returned by the chain for the first requested
            name, extended with the nodes for the remaining names
        :raises RuntimeError: if the effective request is missing from
            lst_candidates or present in lst_excludes
        """
        # Resolve the default before validating, otherwise a missing request
        # (None) is rejected by the candidate check and the default can
        # never take effect.
        if default is not None and self.request is None:
            self.request = default
        if lst_candidates is not None and self.request not in lst_candidates:
            # adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message
            raise RuntimeError(
                f"desired {self.request} is not supported "
                f"among {lst_candidates}"
            )
        if lst_excludes is not None and self.request in lst_excludes:
            raise RuntimeError(
                f"desired {self.request} is not supported among {lst_excludes}"
            )
        chain = TrainerBasic(None)
        chain = TrainerMA(chain)
        chain = TrainerDIAL(chain)
        chain = TrainerMatchDG(chain)
        chain = TrainerMLDG(chain)
        chain = TrainerFishr(chain)
        chain = TrainerIRM(chain)
        chain = TrainerHyperScheduler(chain)
        chain = TrainerCausalIRL(chain)
        chain = TrainerCoral(chain)
        chain = TrainerMiro(chain)
        node = chain.handle(self.request)
        head = node
        # Each recursive call validates and resolves the next queued name
        # (and, through its own loop, everything queued after it), so the
        # returned node already carries the rest of the sequence.
        while self._list_str_trainer:
            self.request = self._list_str_trainer.pop(0)
            node2decorate = self.__call__(lst_candidates, default, lst_excludes)
            head.extend(node2decorate)
            head = node2decorate
        return node