The main objective of domain adaptation is to take a model trained on a labeled source dataset and make it perform well on a target dataset whose distribution is related but different and whose labels are unavailable. In this project, a pretrained RegNetY_400MF is the model being adapted. The adaptation uses Domain-Adversarial Training of Neural Networks (DANN). In short, DANN trains the model adversarially on the source and target datasets together. It adds an extra network as the domain classifier (the critic or discriminator) and applies a gradient reversal layer to the output of the feature extractor, which pushes the feature extractor toward features the domain classifier cannot tell apart. The scheme therefore optimizes two losses: the classification loss, computed on the source dataset only, and the domain loss, computed on both the source and target datasets. Here, the source dataset is MNIST and the target dataset is SVHN. On MNIST, various geometric and photometric data augmentations are applied on the fly during training. To monitor adaptation performance, the SVHN test set serves as both the validation set and the test set.
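The gradient reversal layer and the two losses described above can be sketched in PyTorch as follows. This is a minimal illustration, not the notebook's implementation: the names `GradientReversal`, `dann_loss`, and `lambd`, the equal weighting of the two losses, and the convention that source is domain 0 and target is domain 1 are all assumptions made for the example.

```python
import torch
from torch import nn
from torch.nn import functional as F


class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; lambd is a plain float, so it gets None.
        return -ctx.lambd * grad_output, None


def dann_loss(feats_src, feats_tgt, labels_src, classifier, critic, lambd=1.0):
    """Classification loss on source features plus domain loss on both domains.

    Hypothetical helper: feats_* are feature-extractor outputs of shape (N, D).
    """
    # Class labels exist only for the source domain.
    class_loss = F.cross_entropy(classifier(feats_src), labels_src)

    # Domain labels (assumed convention): 0 = source, 1 = target.
    feats = torch.cat([feats_src, feats_tgt])
    domain_labels = torch.cat([
        torch.zeros(len(feats_src), dtype=torch.long),
        torch.ones(len(feats_tgt), dtype=torch.long),
    ])

    # The critic sees the features through the reversal layer, so minimizing
    # the domain loss trains the critic while pushing the features the other way.
    domain_logits = critic(GradientReversal.apply(feats, lambd))
    domain_loss = F.cross_entropy(domain_logits, domain_labels)
    return class_loss + domain_loss


if __name__ == "__main__":
    torch.manual_seed(0)
    classifier = nn.Linear(8, 10)  # toy stand-in for the 10-digit class head
    critic = nn.Linear(8, 2)       # toy stand-in for the domain classifier
    feats_src = torch.randn(4, 8, requires_grad=True)
    feats_tgt = torch.randn(4, 8, requires_grad=True)
    labels_src = torch.randint(0, 10, (4,))
    loss = dann_loss(feats_src, feats_tgt, labels_src, classifier, critic)
    loss.backward()
    print(loss.item())
```

In a real training step, `feats_src` and `feats_tgt` would come from the shared feature extractor applied to an MNIST batch and an SVHN batch. The target features receive gradients only through the domain loss, since no target labels are used.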
The full adaptation process can be studied in the linked notebook.
The model's performance on the target dataset:
Test Metric | Score |
---|---|
Loss | 3.138 |
Accuracy | 44.79% |
Accuracy curves of the model on the source dataset (MNIST) and the target dataset (SVHN).
Loss curves of the model on the source dataset (MNIST) and the target dataset (SVHN).
The image grid below shows sample predictions on the target dataset.
Some results on the SVHN dataset as the target dataset.
- Domain-Adversarial Training of Neural Networks
- Unsupervised Domain Adaptation by Backpropagation
- DANN
- Designing Network Design Spaces
- Reading Digits in Natural Images with Unsupervised Feature Learning
- TorchVision's SVHN
- Gradient-based learning applied to document recognition
- TorchVision's MNIST
- Semi-supervision and domain adaptation with AdaMatch
- PyTorch Lightning