This repository is the official implementation of the NeurIPS 2021 paper "Domain Invariant Representation Learning with Domain Density Transformations".
Please consider citing our paper:
@article{nguyen2021domain,
title={Domain invariant representation learning with domain density transformations},
author={Nguyen, A. Tuan and Tran, Toan and Gal, Yarin and Baydin, Atilim Gunes},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
The StarGAN code is modified from https://github.com/yunjey/stargan
The Rotated MNIST data loader is modified from https://github.com/AMLab-Amsterdam/DIVA
The data loaders for the other datasets are modified from https://github.com/facebookresearch/DomainBed
Requirements: Python 3, PyTorch 1.7.0 or higher, torchvision 0.8.0 or higher
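To confirm that an environment meets these minimum versions before training, a small check along the following lines may help. This is only a sketch and not part of the repository: the helper names `parse_release`, `meets_minimum`, and `check_requirements` are hypothetical, and the comparison looks only at the leading numeric release (for example, `1.7.0+cu110` is treated as `1.7.0`).

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions stated in this README (keys are the pip distribution names).
MINIMUMS = {"torch": "1.7.0", "torchvision": "0.8.0"}


def parse_release(text):
    """Turn a version string such as '1.7.0+cu110' into a tuple like (1, 7, 0).

    Only the leading numeric components are kept; local build tags and
    pre-release suffixes are ignored. Short versions are padded with zeros.
    """
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if digits == "":
            break
        parts.append(int(digits))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True if the installed release is at least the minimum release."""
    return parse_release(installed) >= parse_release(minimum)


def check_requirements(minimums=MINIMUMS):
    """Return {package: (installed_version_or_None, ok)} for each requirement."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            # Package is not installed in the current environment.
            report[name] = (None, False)
            continue
        report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_requirements().items():
        status = "OK" if ok else "needs >= " + MINIMUMS[name]
        print(name, installed, status)
```

Running the script prints one line per package; a package that is not installed is reported with `None` as its version.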
- To run the Rotated MNIST experiment (for example, with target domain 0 and seed 0):
cd domain_gen_rotatedmnist
CUDA_VISIBLE_DEVICES=0 python train_stargan.py --target_domain 0 # Trains the StarGAN model; you may skip this step because a checkpoint is already provided
CUDA_VISIBLE_DEVICES=0 python -u train.py --model=dirt --seed=0 --epochs=500 --target_domain=0
- To run the PACS experiment (for example, with ResNet-18, target domain 0, and seed 0):
# Change the --data_dir flag to your data directory
cd domain_gen
CUDA_VISIBLE_DEVICES=0 python train_stargan.py --dataset PACS --data_dir ../data/ --target_domain 0 # Trains the StarGAN model; a checkpoint for PACS is already provided
CUDA_VISIBLE_DEVICES=0 python -u main.py --dataset PACS --data_dir ../data --model=dirt --base resnet18 --seed=0 --target_domain=0
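The commands above each run a single target domain and seed. To repeat them over several target domains and seeds, a small driver like the one below could be used. It is a sketch under stated assumptions, not a script shipped with this repository: the helpers `build_jobs` and `run_jobs` are hypothetical, the domain counts (6 for Rotated MNIST, 4 for PACS) and the three seeds are assumed values to adjust to your setup, and it presumes a StarGAN checkpoint is available for every target domain it launches. Only flags that appear in the commands above are used.

```python
import os
import subprocess
from itertools import product


def build_jobs(script, workdir, fixed_args, target_domains, seeds):
    """Return one (workdir, argv) pair per (target domain, seed) combination."""
    jobs = []
    for domain, seed in product(target_domains, seeds):
        argv = ["python", "-u", script, *fixed_args,
                f"--seed={seed}", f"--target_domain={domain}"]
        jobs.append((workdir, argv))
    return jobs


def run_jobs(jobs, gpu="0", dry_run=True):
    """Print each command; launch it as well when dry_run is False."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    for workdir, argv in jobs:
        print(f"(cd {workdir} && CUDA_VISIBLE_DEVICES={gpu} {' '.join(argv)})")
        if not dry_run:
            # check=True stops the sweep at the first failing run.
            subprocess.run(argv, cwd=workdir, env=env, check=True)


# Assumption: 6 Rotated MNIST domains and 3 seeds; adjust as needed.
rotated_mnist = build_jobs(
    "train.py", "domain_gen_rotatedmnist",
    ["--model=dirt", "--epochs=500"],
    target_domains=range(6), seeds=range(3),
)

# Assumption: 4 PACS domains and 3 seeds; change --data_dir to your data directory.
pacs = build_jobs(
    "main.py", "domain_gen",
    ["--dataset", "PACS", "--data_dir", "../data",
     "--model=dirt", "--base", "resnet18"],
    target_domains=range(4), seeds=range(3),
)

if __name__ == "__main__":
    # Dry run by default: only prints the commands that would be executed.
    run_jobs(rotated_mnist + pacs, dry_run=True)
```

Run it from the repository root so that the relative working directories resolve, and pass `dry_run=False` once the printed commands look right.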