
tmllab/2022_NeurIPS_RSA


RSA: Reducing Semantic Shift from Aggressive Augmentations for Self-supervised Learning


Official implementation of RSA: Reducing Semantic Shift from Aggressive Augmentations for Self-supervised Learning (NeurIPS 2022).

Most recent self-supervised learning (SSL) methods learn visual representations by contrasting different augmented views of images. Compared with supervised learning, these methods introduce more aggressive augmentations to further increase the diversity of training pairs. However, aggressive augmentations may distort image structures. This leads to a severe semantic shift problem: augmented views of the same image may no longer share the same semantics, which degrades transfer performance. To address this problem, we propose a new SSL paradigm that counteracts the impact of semantic shift by balancing the roles of weakly and aggressively augmented pairs. Specifically, semantically inconsistent pairs are in the minority, so we treat them as noisy pairs. Deep neural networks (DNNs) exhibit a memorization effect: they tend to first fit clean (majority) examples before overfitting to noisy (minority) examples. Therefore, we assign a relatively large weight to aggressively augmented pairs in the early stage of training. As training proceeds, the model begins to overfit the noisy pairs, so we gradually reduce the weight of the aggressively augmented pairs. In this way, our method can better embrace aggressive augmentations while neutralizing the semantic shift problem. Experiments show that our model achieves 73.1% top-1 accuracy on ImageNet-1K with ResNet-50 trained for 200 epochs, a 2.5% improvement over BYOL. Moreover, experiments demonstrate that the learned representations transfer well to various downstream tasks.
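The weighting idea above can be sketched as a schedule for the weight placed on the aggressively augmented pairs: large early in training, then decaying. This is a minimal sketch under stated assumptions: the cosine shape, the function name `aggressive_weight`, and the endpoints `w_start` and `w_end` are hypothetical choices for illustration, not the exact schedule used in the paper or in this repository.

```python
import math


def aggressive_weight(epoch: int, total_epochs: int,
                      w_start: float = 1.0, w_end: float = 0.0) -> float:
    """Hypothetical weight for the aggressively augmented pairs at a given epoch.

    Assumed cosine decay: the weight starts at w_start (large, while the network
    mostly fits the clean majority of pairs) and falls smoothly to w_end (small,
    once the network would start memorizing semantically shifted pairs).
    """
    if total_epochs <= 1:
        return w_end
    # Fraction of training completed, clamped to [0, 1].
    progress = min(max(epoch / (total_epochs - 1), 0.0), 1.0)
    return w_end + 0.5 * (w_start - w_end) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    print(aggressive_weight(0, 200))    # 1.0
    print(aggressive_weight(199, 200))  # 0.0
```

A smooth decay is used here only because it avoids abrupt changes in the loss; a linear or step schedule would express the same "large early, small late" idea.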

Illustration of our proposed method (RSA). We use an asymmetric framework consisting of an online network and a target network. The online network is optimized by gradient descent, and the target network is updated as an exponential moving average (EMA) of the online network. We first apply weak augmentation to generate two views, then apply aggressive augmentations to generate two additional views. We then enforce consistency in the embedding space between each aggressively augmented view and its corresponding weakly and aggressively augmented views. The right side of the figure compares RSA with classical SSL methods: RSA balances the learned representations between weakly and aggressively augmented views.
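The two mechanisms in the caption, the EMA-updated target network and the consistency objective, can be sketched in plain Python on flat lists of floats. This is an illustrative sketch, not the repository's PyTorch code. The BYOL-style normalized regression loss, the momentum value `tau = 0.99`, and the helper names (`ema_update`, `regression_loss`, `balanced_loss`) are assumptions. Combining the two consistency terms with a single weight `w_aggr` is one possible reading of the caption.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    """Scale a vector to unit length (a zero vector is returned as zeros)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def regression_loss(online_pred: Vector, target_proj: Vector) -> float:
    """BYOL-style loss: squared distance of unit vectors, i.e. 2 - 2 * cosine similarity."""
    p, z = l2_normalize(online_pred), l2_normalize(target_proj)
    return 2.0 - 2.0 * sum(a * b for a, b in zip(p, z))


def ema_update(target: Vector, online: Vector, tau: float = 0.99) -> Vector:
    """Target update without gradients: theta_t <- tau * theta_t + (1 - tau) * theta_o.

    tau = 0.99 is an assumed momentum value, not taken from the repository.
    """
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]


def balanced_loss(pred_aggr: Vector, target_weak: Vector,
                  target_aggr: Vector, w_aggr: float) -> float:
    """Hypothetical RSA-style objective.

    The online prediction for an aggressively augmented view is pulled toward
    the target embedding of the weak view and of the aggressive view; w_aggr
    sets how much the (possibly semantically shifted) aggressive pair counts.
    """
    weak_term = regression_loss(pred_aggr, target_weak)
    aggr_term = regression_loss(pred_aggr, target_aggr)
    return (1.0 - w_aggr) * weak_term + w_aggr * aggr_term


if __name__ == "__main__":
    print(balanced_loss([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], 0.5))  # 1.0
    print(ema_update([0.0, 1.0], [1.0, 1.0], tau=0.9))
```

In a real training loop, `w_aggr` would be decreased over training as described in the abstract above, and the EMA update would be applied to every parameter tensor of the target network after each optimizer step.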

Requirements

  • This codebase is written for Python 3 and PyTorch.
  • To install the required Python packages, run pip install -r requirements.txt.

Experiments

Data

  • Please download all datasets and place them in the data directory.

Training

To train RSA on CIFAR-100

python train_single.py --dataset cifar100 --beta 0.3

To train RSA on STL-10

python train_single.py --dataset stl10

To train RSA on ImageNet-100

python train_multi.py --dataset ImageNet-100 --data_root data/ImageNet-100/

To train RSA on ImageNet-1K

python train_multi.py --dataset ImageNet --lr 0.6 --wd 1e-6 --batch-size 2048 --warmup-epochs 10 --data_root data/ImageNet/
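The ImageNet-1K command sets --lr 0.6 and --warmup-epochs 10. How the training script uses these flags is not documented here; the sketch below assumes a common SSL convention of linear warmup to the base learning rate followed by cosine decay. The function `lr_at_epoch` is hypothetical, and `total_epochs = 200` is taken from the 200-epoch result in the abstract, not from the script's defaults.

```python
import math


def lr_at_epoch(epoch: float, base_lr: float = 0.6,
                warmup_epochs: int = 10, total_epochs: int = 200) -> float:
    """Hypothetical schedule: linear warmup from 0 to base_lr over warmup_epochs,
    then cosine decay from base_lr down to 0 at total_epochs."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    progress = min(progress, 1.0)  # clamp so epochs past the end stay at 0
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for e in (0, 5, 10, 105, 200):
        # prints 0 0.0, 5 0.3, 10 0.6, 105 0.3, 200 0.0 (one pair per line)
        print(e, round(lr_at_epoch(e), 4))
```

In PyTorch, a function like this could drive torch.optim.lr_scheduler.LambdaLR by returning lr_at_epoch(e) / base_lr as the per-epoch multiplier.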

License and Contributing

  • This README follows the paperswithcode README template.
  • Feel free to post issues via GitHub.

Reference

If you find the code useful in your research, please consider citing our paper:

@inproceedings{bai2022rsa,
 author = {Bai, Yingbin and Yang, Erkun and Wang, Zhaoqing and Du, Yuxuan and Han, Bo and Deng, Cheng and Wang, Dadong and Liu, Tongliang},
 booktitle = {NeurIPS},
 title = {RSA: Reducing Semantic Shift from Aggressive Augmentations for Self-supervised Learning},
 volume = {35},
 pages = {21128--21141},
 year = {2022}
}
