Model HDUVA
HDUVA: Hierarchical Variational Auto-Encoding for Unsupervised Domain Generalization
HDUVA builds on a generative approach within the framework of variational autoencoders to facilitate generalization to new domains without supervision. HDUVA learns representations that disentangle domain-specific information from class-label specific information even in complex settings where domain structure is not observed during training.
Model Overview
More specifically, HDUVA is based on three latent variables that are used to model distinct sources of variation, denoted \(z_y\), \(z_d\) and \(z_x\): \(z_y\) represents class-specific information, \(z_d\) represents domain-specific information and \(z_x\) models the residual variance of the input. We introduce an additional hierarchical level and use a continuous latent representation \(s\) to model (potentially unobserved) domain structure. This means that we can encourage disentanglement of the latent variables through conditional priors without the need of conditioning on a one-hot-encoded, observed domain label. The model along with its parameters and hyperparameters is shown in Figure 1.
Note that, as part of the model, a latent representation of \(X\) is concatenated with \(s\) and \(z_d\) (dashed arrows), requiring respective encoder networks.
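To make the dependency structure concrete, the generative factorization implied by the description above (and consistent with the ELBO terms in the next section) can be sketched as follows; this is a reconstruction from the stated dependencies, not a verbatim equation from the paper:

\[
p(x, z_x, z_y, z_d, s \mid y, \alpha) = p_{\theta}(x \mid s, z_d, z_x, z_y)\; p_{\theta_d}(z_d \mid s)\; p_{\theta_y}(z_y \mid y)\; p_{\theta_x}(z_x)\; p_{\theta_s}(s \mid \alpha)
\]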
Evidence lower bound and overall loss
The ELBO of the model can be decomposed into 4 different terms:
Likelihood: \(E_{q(z_d, s|x), q(z_x|x), q(z_y|x)}\log p_{\theta}(x|s, z_d, z_x, z_y)\)
KL divergence weighted as in the Beta-VAE: \(-\beta_x KL(q_{\phi_x}(z_x|x)||p_{\theta_x}(z_x)) - \beta_y KL(q_{\phi_y}(z_y|x)||p_{\theta_y}(z_y|y))\)
Hierarchical KL loss (domain term): \(- \beta_d E_{q_{\phi_s}(s|x), q_{\phi_d}(z_d|x, s)} \log \frac{q_{\phi_d}(z_d|x, s)}{p_{\theta_d}(z_d|s)}\)
Hierarchical KL loss (topic term): \(-\beta_t KL(q_{\phi_s}(s|x)||p_{\theta_s}(s|\alpha))\)
In addition, we construct the overall loss by adding an auxiliary classifier on \(z_y\) as an extra term to the ELBO loss, weighted with \(\gamma_y\) (a sketch is given below).
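As a sketch, assuming the auxiliary classifier takes the form \(q_{\omega}(y|z_y)\) as in the closely related DIVA model (an assumption; the exact parameterization is not restated here), the overall training objective reads:

\[
\mathcal{F}(x, y) = \mathrm{ELBO}(x, y) + \gamma_y\, E_{q_{\phi_y}(z_y|x)} \log q_{\omega}(y|z_y)
\]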
Hyperparameters of the loss function
For fitting the model, we need to specify the four \(\beta\)-weights for the different terms of the ELBO (\(\beta_x\), \(\beta_y\), \(\beta_d\), \(\beta_t\)) as well as \(\gamma_y\).
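For illustration, a run that sets these weights explicitly might look as follows; this assumes each weight is exposed as a command-line flag of the same name (verify with python main_out.py --help), and the numeric values are placeholders, not recommended settings:

python main_out.py --te_d=caltech --bs=2 --task=mini_vlcs --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --beta_x=1e-4 --beta_y=1 --beta_d=1e-2 --beta_t=1e-2 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2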
Model hyperparameters
In addition to these hyperparameters, the following model parameters can be specified:
zd_dim: size of the latent space for domain-specific information
zx_dim: size of the latent space for residual variance
zy_dim: size of the latent space for class-specific information
topic_dim: size of the Dirichlet distribution for topics \(s\)
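Analogously, a hedged example of overriding the latent dimensions, assuming they are exposed as flags of the same name (again, verify with python main_out.py --help; the sizes below are placeholders):

python main_out.py --te_d=caltech --bs=2 --task=mini_vlcs --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --zd_dim=64 --zx_dim=32 --zy_dim=64 --topic_dim=5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2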
The user needs to specify at least two neural networks for the encoder part (a minimal sketch of such a file is given below, after the nname alternatives):

npath_encoder_x2topic_h: the Python file path of a neural network that maps the image (or another data modality) to a one-dimensional hidden representation of size topic_dim, serving as input to the Dirichlet encoder: X -> h_t(X) -> alpha(h_t(X)), where alpha is the neural network that maps the 1-d hidden layer to the Dirichlet concentration parameter.

npath_encoder_sandwich_x2h4zd: the Python file path of a neural network that maps the image to a hidden representation (same size as topic_dim), which will be used to infer the posterior distribution of z_d: topic(X), X -> [topic(X), h_d(X)] -> zd_mean, zd_scale.
Alternatively, one could use an existing neural network in DomainLab by specifying nname instead of npath:
nname_encoder_x2topic_h
nname_encoder_sandwich_x2h4zd
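To illustrate what a file passed via npath_encoder_x2topic_h could contain, here is a minimal PyTorch sketch. The builder function build_net, its signature, and the architecture are hypothetical placeholders; DomainLab expects a specific builder interface, so consult the shipped examples (e.g. examples/nets/resnet.py) for the exact function name and signature:

```python
# Minimal sketch of a custom encoder file for HDUVA's topic encoder
# (--npath_encoder_x2topic_h). The builder name and signature below
# are HYPOTHETICAL; match them to the interface used in DomainLab's
# shipped examples (e.g. examples/nets/resnet.py).
import torch.nn as nn


class TopicHiddenNet(nn.Module):
    """Maps an image batch to a 1-d hidden vector of size dim_out
    (topic_dim), which HDUVA feeds into the Dirichlet encoder."""

    def __init__(self, dim_out, num_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(num_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> (batch, 32, 1, 1)
            nn.Flatten(),             # -> (batch, 32)
            nn.Linear(32, dim_out),   # -> (batch, dim_out)
        )

    def forward(self, x):
        return self.features(x)


def build_net(dim_out):  # hypothetical builder entry point
    return TopicHiddenNet(dim_out)
```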
Hyperparameter for warmup
Finally, the number of epochs for hyper-parameter warm-up can be specified via the argument warmup.
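For example (the flag spelling --warmup and the value below are illustrative assumptions; check python main_out.py --help):

python main_out.py --te_d=caltech --bs=2 --task=mini_vlcs --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --warmup=20 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2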
Please cite our paper if you find it useful!
@inproceedings{sun2021hierarchical,
title={Hierarchical Variational Auto-Encoding for Unsupervised Domain Generalization},
author={Sun, Xudong and Buettner, Florian},
booktitle={ICLR 2021 RobustML workshop, https://arxiv.org/pdf/2101.09436.pdf},
year={2021}
}
Examples
hduva with a custom net for the sandwich encoder
python main_out.py --te_d=caltech --bs=2 --task=mini_vlcs --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --npath_encoder_sandwich_x2h4zd=examples/nets/resnet.py
hduva with a custom net for the topic encoder
python main_out.py --te_d=caltech --bs=2 --task=mini_vlcs --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --npath_encoder_x2topic_h=examples/nets/resnet.py --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
hduva with a custom net for the classification encoder
python main_out.py --te_d=caltech --bs=2 --task=mini_vlcs --model=hduva --npath=examples/nets/resnet.py --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
use hduva on color mnist, train on three domains
python main_out.py --tr_d 0 1 2 --te_d 3 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
hduva is domain-unsupervised, so it also works with a single training domain
python main_out.py --tr_d 0 --te_d 3 4 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
hduva with a built-in neural network
python main_out.py --te_d=caltech --bs=2 --task=mini_vlcs --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
hduva with AlexNet as the sandwich encoder
python main_out.py --te_d=caltech --bs=2 --task=mini_vlcs --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=alexnet