Source code for domainlab.arg_parser

"""
Command line arguments
"""
import argparse
import warnings

import yaml

from domainlab.algos.trainers.args_dial import add_args2parser_dial
from domainlab.algos.trainers.compos.matchdg_args import add_args2parser_matchdg
from domainlab.algos.trainers.args_miro import add_args2parser_miro
from domainlab.models.args_jigen import add_args2parser_jigen
from domainlab.models.args_vae import add_args2parser_vae
from domainlab.utils.logger import Logger

class ParseValuesOrKeyValuePairs(argparse.Action):
    """Class used for arg parsing where values are provided in a key value format"""

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str,
        option_string: str = None,
    ):
        """
        Handle parsing of key value pairs, or a single value instead

        If ``values`` contains "=", it is read as a comma separated list of
        ``key=value`` pairs and a ``dict`` mapping each (stripped) key to a
        float is stored on the namespace. Otherwise ``values`` is converted
        to a single float.

        Args:
            parser (argparse.ArgumentParser): The ArgumentParser object (unused).
            namespace (argparse.Namespace): The namespace object to store parsed values.
            values (str): The string containing key=value pairs or a single float value.
            option_string (str, optional): The option string that triggered this action (unused).

        Raises:
            ValueError: If a value cannot be parsed to float, or if an entry
                of a key-value list is not of the form ``key=value``.
        """
        if "=" not in values:
            try:
                setattr(namespace, self.dest, float(values))
            except ValueError as err:
                raise ValueError(
                    f"Invalid value for {self.dest}: '{values}', must be float"
                ) from err
            return

        my_dict = {}
        for kv in values.split(","):
            # partition never raises, so a malformed entry (missing "=",
            # empty key, trailing comma) gets an explicit error message
            # instead of an opaque tuple-unpacking failure
            key, separator, raw_value = kv.partition("=")
            key = key.strip()
            if not separator or not key:
                raise ValueError(
                    f"Invalid key-value pair: '{kv}', expected the form key=value"
                )
            try:
                my_dict[key] = float(raw_value.strip())
            except ValueError as err:
                raise ValueError(
                    f"Invalid value in key-value pair: '{kv}', must be float"
                ) from err
        setattr(namespace, self.dest, my_dict)
def _str2bool(value):
    """Convert a command line token such as 'true' / 'false' into a bool."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got '{value}'")


def mk_parser_main():
    """
    Args for command line definition

    Builds the DomainLab ``argparse.ArgumentParser``: the general options,
    the "task args" group and the argument groups filled in by the
    vae / matchdg / miro / jigen / dial helper functions.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="DomainLab")

    parser.add_argument(
        "-c",
        "--config",
        default=None,
        help="load YAML configuration",
        dest="config_file",
        type=argparse.FileType(mode="r"),
    )

    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")

    parser.add_argument(
        "--gamma_reg",
        default=0.1,
        # raw strings: the LaTeX backslashes must not be read as escape
        # sequences ("\t" in "\times" would otherwise become a TAB)
        help="weight of regularization loss in the form of "
        r"$$\ell(\cdot) + \mu \times R(\cdot)$$ "
        "can specify per model as 'default=3.0, dann=1.0,jigen=2.0', "
        "where default refer to gamma for trainer "
        r"note diva is implemented $$\ell(\cdot) + \mu \times R(\cdot)$$ "
        "so diva does not have gamma_reg",
        action=ParseValuesOrKeyValuePairs,
    )

    parser.add_argument("--es", type=int, default=1, help="early stop steps")

    parser.add_argument("--seed", type=int, default=0, help="random seed (default: 0)")

    parser.add_argument(
        "--nocu", action="store_true", default=False, help="disables CUDA"
    )

    parser.add_argument(
        "--device", type=str, default=None, help="device name default None"
    )

    parser.add_argument(
        "--gen", action="store_true", default=False, help="save generated images"
    )

    parser.add_argument(
        "--keep_model",
        action="store_true",
        default=False,
        help="do not delete model at the end of training",
    )

    parser.add_argument("--epos", default=2, type=int, help="maximum number of epochs")

    parser.add_argument(
        "--epos_min", default=0, type=int, help="minimum number of epochs"
    )

    parser.add_argument(
        "--epo_te", default=1, type=int, help="test performance per {} epochs"
    )

    parser.add_argument(
        "-w",
        "--warmup",
        type=int,
        default=100,
        help="number of epochs for hyper-parameter warm-up. "
        "Set to 0 to turn warmup off.",
    )

    parser.add_argument(
        "-nb4ratio",
        "--nb4reg_over_task_ratio",
        type=int,
        default=1,
        help="number of batches for estimating reg loss over task loss ratio "
        "default 1",
    )

    parser.add_argument("--debug", action="store_true", default=False)

    parser.add_argument("--dmem", action="store_true", default=False)

    parser.add_argument(
        "--no_dump",
        action="store_true",
        default=False,
        help="suppress saving the confusion matrix",
    )

    parser.add_argument(
        "--trainer", type=str, default=None, help="specify which trainer to use"
    )

    parser.add_argument(
        "--out", type=str, default="zoutput", help="absolute directory to store outputs"
    )

    parser.add_argument(
        "--dpath",
        type=str,
        default="zdpath",
        help="path for storing downloaded dataset",
    )

    parser.add_argument(
        "--tpath",
        type=str,
        default=None,
        help="path for custom task, should implement get_task function",
    )

    parser.add_argument(
        "--npath",
        type=str,
        default=None,
        help="path of custom neural network for feature extraction",
    )

    parser.add_argument(
        "--npath_dom",
        type=str,
        default=None,
        help="path of custom neural network for feature extraction",
    )

    parser.add_argument(
        "--npath_argna2val",
        action="append",
        help="specify new arguments and their value "
        "e.g. '--npath_argna2val my_custom_arg_na "
        "--npath_argna2val xx/yy/zz.py', additional "
        "pairs can be appended",
    )

    parser.add_argument(
        "--nname_argna2val",
        action="append",
        help="specify new arguments and their values "
        "e.g. '--nname_argna2val my_custom_network_arg_na "
        "--nname_argna2val alexnet', additional pairs "
        "can be appended",
    )

    parser.add_argument(
        "--nname",
        type=str,
        default=None,
        help="name of custom neural network for feature "
        "extraction of classification",
    )

    parser.add_argument(
        "--nname_dom",
        type=str,
        default=None,
        help="name of custom neural network for feature extraction of domain",
    )

    parser.add_argument(
        "--apath", type=str, default=None, help="path for custom AlgorithmBuilder"
    )

    parser.add_argument(
        "--exptag",
        type=str,
        default="exptag",
        help="tag as prefix of result aggregation file name "
        "e.g. git hash for reproducibility",
    )

    parser.add_argument(
        "--aggtag",
        type=str,
        default="aggtag",
        help="tag in each line of result aggregation file "
        "e.g., to specify potential different configurations",
    )

    parser.add_argument(
        "--agg_partial_bm",
        type=str,
        default=None,
        dest="bm_dir",
        help="Aggregates and plots partial data of a snakemake "
        "benchmark. Requires the benchmark config file. "
        "Other arguments will be ignored.",
    )

    parser.add_argument(
        "--gen_plots",
        type=str,
        default=None,
        dest="plot_data",
        help="plots the data of a snakemake benchmark. "
        "Requires the results.csv file "
        "and an output directory (specify by --outp_dir, "
        "default is zoutput/benchmarks/shell_benchmark). "
        "Other arguments will be ignored.",
    )

    parser.add_argument(
        "--outp_dir",
        type=str,
        default="zoutput/benchmarks/shell_benchmark",
        dest="outp_dir",
        help="output directory for the plots when creating them "
        "using --gen_plots. "
        "Default is zoutput/benchmarks/shell_benchmark",
    )

    parser.add_argument(
        "--param_idx",
        # type=bool would turn every non-empty string (including "False")
        # into True, so an explicit string-to-bool converter is used
        type=_str2bool,
        default=True,
        dest="param_idx",
        help="True: parameter index is used in the "
        "plots generated with --gen_plots. "
        "False: parameter name is used. "
        "Default is True.",
    )

    parser.add_argument(
        "--msel",
        choices=["val", "loss_tr"],
        default="val",
        help="model selection for early stop: val, loss_tr, recon, the "
        "elbo and recon only make sense for vae models, "
        "will be ignored by other methods",
    )

    parser.add_argument(
        "--model", metavar="an", type=str, default=None, help="algorithm name"
    )

    parser.add_argument(
        "--acon",
        metavar="ac",
        type=str,
        default=None,
        help="algorithm configuration name, (default None)",
    )

    parser.add_argument("--task", metavar="ta", type=str, help="task name")

    parser.add_argument(
        "--val_threshold",
        type=float,
        default=None,
        help="Accuracy threshold before early stopping can be applied",
    )

    arg_group_task = parser.add_argument_group("task args")

    arg_group_task.add_argument(
        "--bs", type=int, default=100, help="loader batch size for mixed domains"
    )

    arg_group_task.add_argument(
        "--split",
        type=float,
        default=0,
        help="proportion of training, a value between "
        "0 and 1, 0 means no train-validation split",
    )

    arg_group_task.add_argument(
        "--te_d",
        nargs="*",
        default=None,
        help="test domain names separated by single space, "
        "will be parsed to be list of strings",
    )

    arg_group_task.add_argument(
        "--tr_d",
        nargs="*",
        default=None,
        help="training domain names separated by "
        "single space, will be parsed to be list of "
        "strings; if not provided then all available "
        "domains that are not assigned to "
        "the test set will be used as training domains",
    )

    arg_group_task.add_argument(
        "--san_check",
        action="store_true",
        default=False,
        help="save images from the dataset as a sanity check",
    )

    arg_group_task.add_argument(
        "--san_num",
        type=int,
        default=8,
        help="number of images to be dumped for the sanity check",
    )

    arg_group_task.add_argument(
        "--loglevel", type=str, default="DEBUG", help="sets the loglevel of the logger"
    )

    arg_group_task.add_argument(
        "--shuffling_off",
        action="store_true",
        default=False,
        help="disable shuffling of the training dataloader for the dataset",
    )

    # args for variational auto encoder
    arg_group_vae = parser.add_argument_group("vae")
    arg_group_vae = add_args2parser_vae(arg_group_vae)
    arg_group_matchdg = parser.add_argument_group("matchdg")
    arg_group_matchdg = add_args2parser_matchdg(arg_group_matchdg)
    arg_group_miro = parser.add_argument_group("miro")
    arg_group_miro = add_args2parser_miro(arg_group_miro)
    arg_group_jigen = parser.add_argument_group("jigen")
    arg_group_jigen = add_args2parser_jigen(arg_group_jigen)
    args_group_dial = parser.add_argument_group("dial")
    args_group_dial = add_args2parser_dial(args_group_dial)
    return parser
def apply_dict_to_args(args, data: dict, extend=False):
    """
    Apply the entries of ``data`` to the args namespace of DomainLab.

    For every ``key``/``value`` pair in ``data``:

    * a non-list ``value`` overwrites ``args.<key>``;
    * a list ``value`` is appended to the list already stored in
      ``args.<key>`` (existing entries are kept); if the current value is
      ``None`` or the key is new, a fresh list is created first.

    Entries processed before an error is raised stay applied.

    Args:
        args: namespace-like object whose ``__dict__`` is updated in place.
        data (dict): values to apply; ``None`` or an empty dict is a no-op.
        extend (bool): if True, keys that do not yet exist in ``args`` are
            added; if False, such keys raise ``ValueError``.

    Raises:
        ValueError: if ``key`` is not an attribute of ``args`` and ``extend``
            is not set.
        RuntimeError: if ``value`` is a list but the current value in
            ``args`` is neither a list nor ``None``.
    """
    if not data:
        # nothing to apply (e.g. an empty configuration)
        return

    arg_dict = args.__dict__
    for key, value in data.items():
        if key not in arg_dict and not extend:
            raise ValueError(f"Unsupported key: {key}")

        if not isinstance(value, list):
            # over-write existing value
            arg_dict[key] = value
            continue

        cur_val = arg_dict.get(key, None)
        if cur_val is None:
            # no value so far: start from an empty list, domainlab will
            # take care of it if this argument can not be a list
            arg_dict[key] = []
        elif not isinstance(cur_val, list):
            raise RuntimeError(
                f"input dictionary value for '{key}' is a list, "
                f"however, in DomainLab args, we have {cur_val}, "
                "which is not a list and can not be extended"
            )
        # keep existing values of the list arg_dict[key]
        arg_dict[key].extend(value)
def parse_cmd_args():
    """
    get args from command line

    Parses ``sys.argv`` with the parser from ``mk_parser_main``. If a YAML
    configuration file was given, its content is applied on top of the
    parsed arguments via ``apply_dict_to_args`` and the ``config_file``
    attribute is removed from the result. A warning is emitted when neither
    ``acon`` nor ``bm_dir`` is set.

    Returns:
        argparse.Namespace: the parsed (and possibly config-updated) args.
    """
    import sys

    parser = mk_parser_main()
    args = parser.parse_args()
    logger = Logger.get_logger(logger_name="main_out_logger", loglevel=args.loglevel)
    if args.config_file:
        # argparse.FileType hands over an already opened handle; close it
        # once the YAML content is read (but never close stdin, which
        # FileType returns for "-")
        config_handle = args.config_file
        try:
            data = yaml.safe_load(config_handle)
        finally:
            if config_handle is not sys.stdin:
                config_handle.close()
        delattr(args, "config_file")
        # yaml.safe_load returns None for an empty document
        apply_dict_to_args(args, data or {})

    if args.acon is None and args.bm_dir is None:
        # NOTE(review): assumes Logger.get_logger returns a logging.Logger
        # compatible object; `warn` is the deprecated alias of `warning`
        logger.warning("\n\n")
        logger.warning("no algorithm conf specified, going to use default")
        logger.warning("\n\n")
        warnings.warn("no algorithm conf specified, going to use default")

    return args