Trainer DIAL
Domain Invariant Adversarial Learning
The algorithm introduced in https://arxiv.org/pdf/2104.00322.pdf uses adversarial learning to tackle domain generalization. Here, the source domain is the natural dataset, while the target domain is generated by an adversarial attack on the source domain.
generating the adversarial domain
The generation of adversarial images is illustrated in figure 1. The task is to find, for a natural image \(x\), an adversarial image \(x'\) such that \(||x - x'||\) is small while the output of a classification network \(\phi\) changes considerably, i.e. \(||\phi(x) - \phi(x')||\) is large. In the example in figure 1, the difference between the left and the right image of the panda is imperceptible, yet the classifier assigns them to different classes.
In DomainLab, the adversarial images are created by starting from a random perturbation of the natural image, \(x'_0 = x + \sigma \tilde{x}\) with \(\tilde{x} \sim \mathcal{N}(0, 1)\), and then taking \(n\) gradient steps with step size \(\tau\) that maximize \(||\phi(x) - \phi(x')||\) (i.e. gradient ascent). More generally in machine learning, adversarial images are generated during training to make networks more robust to adversarial attacks.
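The procedure described above can be sketched in a few lines of NumPy. This is a minimal, hypothetical illustration, not DomainLab's implementation: the classifier \(\phi\) is replaced by a toy linear map so that the gradient has a closed form, the distance is the squared Euclidean norm, and we assume that the `--dial_epsilon` threshold acts as a pixel-wise box around \(x\), with images clipped to \([0, 1]\). The names `generate_adversarial` and `project_to_box` are made up for this example.

```python
import numpy as np


def project_to_box(x_adv, x, epsilon):
    """Hypothetical helper: keep each pixel within epsilon of x and inside [0, 1]."""
    x_adv = np.clip(x_adv, x - epsilon, x + epsilon)
    return np.clip(x_adv, 0.0, 1.0)


def generate_adversarial(x, phi_weight, sigma, n_steps, tau, epsilon, rng):
    """Sketch of the attack loop; phi(x) = phi_weight @ x is a toy linear model (assumption)."""
    # random start: x'_0 = x + sigma * x_tilde with x_tilde ~ N(0, 1)
    x_adv = project_to_box(x + sigma * rng.standard_normal(x.shape), x, epsilon)
    for _ in range(n_steps):
        # gradient of ||phi(x) - phi(x')||^2 w.r.t. x' is 2 W^T W (x' - x)
        grad = 2.0 * phi_weight.T @ (phi_weight @ (x_adv - x))
        # ascent step of size tau, because the distance is maximized
        x_adv = project_to_box(x_adv + tau * grad, x, epsilon)
    return x_adv


rng = np.random.default_rng(0)
x = rng.uniform(0.2, 0.8, size=16)   # toy "image" with 16 pixels
W = rng.standard_normal((3, 16))     # toy 3-class linear "classifier"
x_adv = generate_adversarial(x, W, sigma=0.01, n_steps=5, tau=0.01, epsilon=0.05, rng=rng)
print(np.linalg.norm(W @ x - W @ x_adv))  # output-space distance after the attack
```

The random start is not cosmetic: at \(x' = x\) the gradient of the distance is exactly zero, so gradient ascent started from the natural image itself would never move.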
network structure
The network consists of three parts. First, a feature extractor captures the main characteristics of the images. These features are then used as the input to a label classifier and a domain classifier. During training, the network is optimized to have a low error on the classification task while ensuring that the internal representation (the output of the feature extractor) cannot discriminate between the natural and the adversarial domain. This goal can be achieved by using a special loss function in combination with a gradient reversal layer.
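To see why a gradient reversal layer pushes the features toward domain invariance, here is a minimal NumPy sketch with hand-written backpropagation. It is a generic illustration of the technique, not DomainLab's code: the linear feature extractor `F`, the linear domain classifier `v`, the class `GradientReversal`, the helper `domain_head_gradients` and the scale `lambd` are all hypothetical, and the upstream gradient of the domain loss, `dloss_dlogit`, is simply passed in.

```python
import numpy as np


class GradientReversal:
    """Identity on the forward pass; scales the gradient by -lambd on the backward pass."""

    def __init__(self, lambd=1.0):
        self.lambd = lambd

    def forward(self, features):
        return features

    def backward(self, grad_output):
        return -self.lambd * grad_output


def domain_head_gradients(F, v, x, dloss_dlogit, grl):
    """Manual backprop for: features h = F x -> GRL -> domain logit = v . h."""
    h = F @ x
    logit = v @ grl.forward(h)
    grad_v = dloss_dlogit * h                # domain classifier: ordinary gradient
    grad_h = grl.backward(dloss_dlogit * v)  # sign flipped before reaching the features
    grad_F = np.outer(grad_h, x)             # feature extractor: reversed gradient
    return logit, grad_v, grad_F


rng = np.random.default_rng(0)
F = rng.standard_normal((4, 8))  # toy feature extractor weights
v = rng.standard_normal(4)       # toy domain classifier weights
x = rng.standard_normal(8)       # toy input
logit, grad_v, grad_F = domain_head_gradients(F, v, x, dloss_dlogit=1.0, grl=GradientReversal(1.0))
```

A gradient descent step on `F` with `grad_F` therefore moves the features in the direction that increases the domain loss, while the domain classifier `v` is still trained to decrease it. The two play against each other until the features no longer reveal whether an image is natural or adversarial.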
loss function and gradient reversal layer
The loss function in the DomainLab package differs from the one described in the paper. It consists of the standard cross-entropy loss between the predicted label probabilities and the true label, computed for the natural domain (\(CE_{nat}\)) and for the adversarial domain (\(CE_{adv}\)). The adversarial term is weighted by the parameter \(\gamma_\text{reg}\).
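Written out, this objective is \(\mathcal{L} = CE_{nat} + \gamma_\text{reg}\, CE_{adv}\). The NumPy sketch below computes it for a batch of logits. It is an illustration rather than DomainLab's code; the function names are hypothetical, the batch mean is our assumption, and we assume each adversarial image inherits the label of the natural image it was generated from.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean cross-entropy between softmax(logits) and integer class labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def dial_loss(logits_nat, logits_adv, labels, gamma_reg):
    """CE_nat + gamma_reg * CE_adv; adversarial images reuse the natural labels."""
    return cross_entropy(logits_nat, labels) + gamma_reg * cross_entropy(logits_adv, labels)


labels = np.array([0, 2])
logits_nat = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
logits_adv = np.array([[0.5, 1.5, -1.0], [1.0, 0.2, 0.8]])
print(dial_loss(logits_nat, logits_adv, labels, gamma_reg=0.5))
```

Setting `gamma_reg=0` recovers plain training on natural images; larger values put more weight on classifying the adversarial images correctly.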
This procedure yields the following hyperparameters:
--dial_steps_perturb
: number of gradient steps used to find an adversarial image (\(n\) from “generating the adversarial domain”)
--dial_noise_scale
: scale (standard deviation) of the Gaussian noise injected into the natural image (\(\sigma\) from “generating the adversarial domain”)
--dial_lr
: learning rate (step size) used to generate adversarial images (\(\tau\) from “generating the adversarial domain”)
--dial_epsilon
: pixel-wise threshold for perturbing the images
--gamma_reg
: weight of the adversarial cross-entropy term \(CE_{adv}\) (\(\epsilon\) in the paper)
--lr
: learning rate of the network (\(\alpha\) in the paper)
Examples
python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=dial --nname=conv_bn_pool_2
python main_out.py --te_d=0 --task=mnistcolor10 --keep_model --model=erm --trainer=dial --nname=conv_bn_pool_2
Train DIVA model with DIAL trainer
python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=diva --nname=conv_bn_pool_2 --nname_dom=conv_bn_pool_2 --gamma_y=7e5 --gamma_d=1e5 --trainer=dial
Set hyper-parameters for trainer as well
python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=diva --nname=conv_bn_pool_2 --nname_dom=conv_bn_pool_2 --gamma_y=7e5 --gamma_d=1e5 --trainer=dial --dial_steps_perturb=1