[Feature] add ORL #668

Status: Open. Wants to merge 18 commits into base: dev-1.x
2 changes: 1 addition & 1 deletion .gitignore
@@ -3,7 +3,7 @@ __pycache__/
*.py[cod]
*$py.class
**/*.pyc

*.out
# C extensions
*.so

57 changes: 57 additions & 0 deletions configs/selfsup/_base_/datasets/coco_orl_stage1.py
@@ -0,0 +1,57 @@
import copy

# dataset settings
dataset_type = 'mmdet.CocoDataset'
# data_root = 'data/coco/'
data_root = '../data/coco/'
file_client_args = dict(backend='disk')
view_pipeline = [
    dict(
        type='RandomResizedCrop',
        size=224,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandomApply',
        transforms=[
            dict(
                type='ColorJitter',
                brightness=0.4,
                contrast=0.4,
                saturation=0.2,
                hue=0.1)
        ],
        prob=0.8),
    dict(
        type='RandomGrayscale',
        prob=0.2,
        keep_channels=True,
        channel_weights=(0.114, 0.587, 0.2989)),
    dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1),
    dict(type='RandomSolarize', prob=0)
]
view_pipeline1 = copy.deepcopy(view_pipeline)
view_pipeline2 = copy.deepcopy(view_pipeline)
view_pipeline2[4]['prob'] = 0.1  # gaussian blur
view_pipeline2[5]['prob'] = 0.2  # solarization
train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='MultiView',
        num_views=[1, 1],
        transforms=[view_pipeline1, view_pipeline2]),
    dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]
train_dataloader = dict(
    batch_size=64,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=train_pipeline))
89 changes: 89 additions & 0 deletions configs/selfsup/_base_/datasets/coco_orl_stage3.py
@@ -0,0 +1,89 @@
import copy

# dataset settings
dataset_type = 'ORLDataset'
meta_json = '../data/coco/meta/train2017_10nn_instance_correspondence.json'
data_train_root = '../data/coco/train2017'
# file_client_args = dict(backend='disk')
view_pipeline = [
    dict(
        type='RandomResizedCrop',
        size=224,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandomApply',
        transforms=[
            dict(
                type='ColorJitter',
                brightness=0.4,
                contrast=0.4,
                saturation=0.2,
                hue=0.1)
        ],
        prob=0.8),
    dict(
        type='RandomGrayscale',
        prob=0.2,
        keep_channels=True,
        channel_weights=(0.114, 0.587, 0.2989),
        color_format='rgb'),
    dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1.),
    dict(type='RandomSolarize', prob=0)
]

view_patch_pipeline = [
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandomApply',
        transforms=[
            dict(
                type='ColorJitter',
                brightness=0.4,
                contrast=0.4,
                saturation=0.2,
                hue=0.1)
        ],
        prob=0.8),
    dict(
        type='RandomGrayscale',
        prob=0.2,
        keep_channels=True,
        channel_weights=(0.114, 0.587, 0.2989),
        color_format='rgb'),
    dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1.),
    dict(type='RandomSolarize', prob=0)
]
view_pipeline1 = copy.deepcopy(view_pipeline)
view_pipeline2 = copy.deepcopy(view_pipeline)
view_patch_pipeline1 = copy.deepcopy(view_patch_pipeline)
view_patch_pipeline2 = copy.deepcopy(view_patch_pipeline)
view_pipeline2[4]['prob'] = 0.1 # gaussian blur
view_pipeline2[5]['prob'] = 0.2 # solarization
view_patch_pipeline2[3]['prob'] = 0.1  # gaussian blur
view_patch_pipeline2[4]['prob'] = 0.2 # solarization

train_dataloader = dict(
    batch_size=64,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type=dataset_type,
        root=data_train_root,
        json_file=meta_json,
        topk_knn_image=10,
        img_pipeline1=view_pipeline1,
        img_pipeline2=view_pipeline2,
        patch_pipeline1=view_patch_pipeline1,
        patch_pipeline2=view_patch_pipeline2,
        patch_size=96,
        interpolation=2,
        shift=(-0.5, 0.5),
        scale=(0.5, 2.),
        ratio=(0.5, 2.),
        iou_thr=0.5,
        attempt_num=200,
    ))
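The `shift`, `scale`, `ratio`, `iou_thr`, and `attempt_num` arguments above suggest that `ORLDataset` jitters each RoI box and keeps a jittered copy only if it still overlaps the original enough. The dataset implementation is not shown in this part of the diff, so the following is only a minimal sketch under that assumption; `iou` and `jitter_box` are hypothetical helper names, the sampling scheme is a guess, and clipping to image bounds is omitted.

```python
import random


def iou(box_a, box_b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def jitter_box(box, shift=(-0.5, 0.5), scale=(0.5, 2.0), ratio=(0.5, 2.0),
               iou_thr=0.5, attempt_num=200, rng=random):
    """Assumed behavior: resample a shifted/rescaled box until IoU >= iou_thr."""
    x1, y1, x2, y2 = box
    w, h = x2 - x1, y2 - y1
    cx, cy = x1 + w / 2, y1 + h / 2
    for _ in range(attempt_num):
        # Shift the center by a fraction of the box size.
        new_cx = cx + rng.uniform(*shift) * w
        new_cy = cy + rng.uniform(*shift) * h
        # Rescale the area and perturb the aspect ratio.
        area = w * h * rng.uniform(*scale)
        aspect = (w / h) * rng.uniform(*ratio)
        new_w, new_h = (area * aspect) ** 0.5, (area / aspect) ** 0.5
        candidate = (new_cx - new_w / 2, new_cy - new_h / 2,
                     new_cx + new_w / 2, new_cy + new_h / 2)
        if iou(box, candidate) >= iou_thr:
            return candidate
    return box  # Fall back to the original box if every attempt fails.


original = (10.0, 10.0, 50.0, 40.0)
jittered = jitter_box(original, rng=random.Random(0))
```

Under these assumptions, a larger `attempt_num` makes the fallback to the unjittered box rarer, while a higher `iou_thr` keeps the two views of a patch closer together.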
41 changes: 41 additions & 0 deletions configs/selfsup/_base_/models/orl.py
@@ -0,0 +1,41 @@
# model settings
model = dict(
    type='ORL',
    base_momentum=0.99,
    data_preprocessor=dict(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        # mean=(123.675, 116.28, 103.53),
        # std=(58.395, 57.12, 57.375),
        bgr_to_rgb=False),
    pretrained=None,
    global_loss_weight=1.,
    loc_intra_loss_weight=1.,
    loc_inter_loss_weight=1.,
    backbone=dict(
        type='ResNet',
        depth=50,
        in_channels=3,
        out_indices=[4],  # 0: conv-1, x: stage-x
        norm_cfg=dict(type='SyncBN')),
    neck=dict(
        type='NonLinearNeck',
        in_channels=2048,
        hid_channels=4096,
        out_channels=256,
        num_layers=2,
        with_bias=False,
        with_last_bn=False,
        with_avg_pool=True),
    head=dict(
        type='LatentPredictHead',
        predictor=dict(
            type='NonLinearNeck',
            in_channels=256,
            hid_channels=4096,
            out_channels=256,
            num_layers=2,
            with_bias=False,
            with_last_bn=False,
            with_avg_pool=False),
        loss=dict(type='CosineSimilarityLoss')))
141 changes: 141 additions & 0 deletions configs/selfsup/orl/README.md
@@ -0,0 +1,141 @@
# ORL

> [Unsupervised Object-Level Representation Learning
> from Scene Images
> ](https://arxiv.org/abs/2106.11952)

<!-- [ALGORITHM] -->

## Abstract

Contrastive self-supervised learning has largely narrowed the gap to supervised pre-training on ImageNet. However, its success highly relies on the object-centric priors of ImageNet, i.e., different augmented views of the same image correspond to the same object. Such a heavily curated constraint becomes immediately infeasible when pre-trained on more complex scene images with many objects. To overcome this limitation, we introduce Object-level Representation Learning (ORL), a new self-supervised learning framework towards scene images. Our key insight is to leverage image-level self-supervised pre-training as the prior to discover object-level semantic correspondence, thus realizing object-level representation learning from scene images. Extensive experiments on COCO show that ORL significantly improves the performance of self-supervised learning on scene images, even surpassing supervised ImageNet pre-training on several downstream tasks. Furthermore, ORL improves the downstream performance when more unlabeled scene images are available, demonstrating its great potential of harnessing unlabeled data in the wild. We hope our approach can motivate future research on more general-purpose unsupervised representation learning from scene data.

<div align="center">
<img src="https://github.com/Jiahao000/ORL/raw/2ad64f7389d20cb1d955792aabbe806a7097e6fb/highlights.png" width="90%" />
</div>

## Usage

ORL is mainly composed of three stages. In Stage 1, we pre-train an image-level contrastive learning model, e.g., BYOL. In Stage 2, we first use the pre-trained model to retrieve KNNs for each image in the embedding space to obtain image-level visually similar pairs. We then use unsupervised region proposal algorithms (e.g., selective search) to generate rough RoIs for each image pair. Afterwards, we reuse the pre-trained model to retrieve the top-ranked RoI pairs, i.e., correspondence. We find that these RoI pairs are almost always objects or object parts. In Stage 3, with the corresponding RoI pairs discovered across images, we finally perform object-level contrastive learning using the same architecture as Stage 1.

### Stage 1: Image-level pre-training

In Stage 1, ORL pre-trains an image-level contrastive learning model. At the end of pre-training, it extracts features for all images in the training set and retrieves KNNs for each image in the embedding space to obtain image-level visually similar pairs.

```shell
# Train with multiple GPUs
bash tools/dist_train.sh \
    configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py \
    ${GPUS} \
    --work-dir work_dirs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco/
```

or

```shell
# Train on a cluster managed with Slurm
GPUS_PER_NODE=${GPUS_PER_NODE} GPUS=${GPUS} CPUS_PER_TASK=${CPUS_PER_TASK} \
    bash tools/slurm_train.sh ${PARTITION} ${JOB_NAME} \
    configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py \
    --work-dir work_dirs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco/
```

The corresponding KNN image ids will be saved as a JSON file `train2017_10nn_instance.json` under `../data/coco/meta/`.
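
The KNN retrieval at the end of Stage 1 can be illustrated with a small, dependency-free sketch. It assumes cosine similarity on L2-normalized features as the distance in embedding space, which this README does not spell out; `l2_normalize` and `retrieve_knn` are hypothetical names, and a real run would use batched tensor operations rather than Python loops.

```python
import math


def l2_normalize(vec):
    """Scale a feature vector to unit length (returned unchanged if all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def retrieve_knn(features, k):
    """For each feature, return the indices of its k most similar other features."""
    normed = [l2_normalize(f) for f in features]
    neighbors = []
    for i, query in enumerate(normed):
        # Cosine similarity is the dot product of unit-length vectors.
        sims = [(sum(a * b for a, b in zip(query, other)), j)
                for j, other in enumerate(normed) if j != i]
        sims.sort(key=lambda pair: pair[0], reverse=True)
        neighbors.append([j for _, j in sims[:k]])
    return neighbors


# Toy "embeddings" for four images; the config in this PR uses topk_knn_image=10.
features = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
knn = retrieve_knn(features, k=1)
```

Here each image is paired with its single nearest neighbor; with `k=10` the result would correspond to the ten visually similar images stored per training image.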

### Stage 2: Correspondence discovery

- **RoI generation**

ORL applies selective search to generate region proposals for all images in the training set:

```shell
# Run on a single GPU
bash tools/dist_selective_search_single_gpu.sh \
    configs/selfsup/orl/stage2/selective_search.py \
    ../data/coco/meta/train2017_selective_search_proposal.json \
    --work-dir work_dirs/selfsup/orl/stage2/selective_search
```

or

```shell
# Run on a cluster managed with Slurm
GPUS_PER_NODE=${GPUS_PER_NODE} GPUS=1 CPUS_PER_TASK=${CPUS_PER_TASK} \
    bash tools/slurm_selective_search_single_gpu.sh ${PARTITION} \
    configs/selfsup/orl/stage2/selective_search.py \
    ../data/coco/meta/train2017_selective_search_proposal.json \
    --work-dir work_dirs/selfsup/orl/stage2/selective_search
```

The script and config only support single-image, single-GPU inference, since selective search can generate a different number of region proposals for each image, and these cannot be gathered when distributed across multiple GPUs. If you want to skip this step, you can directly download the results [here](https://drive.google.com/drive/folders/1yYsyGiDjjVSOzIUkhxwO_NitUPLC-An_?usp=sharing).

- **RoI pair retrieval**

ORL reuses the model pre-trained in stage 1 to retrieve the top-ranked RoI pairs, i.e., correspondence.

```shell
# Run on a single GPU
bash tools/dist_generate_correspondence_single_gpu.sh \
    configs/selfsup/orl/stage2/correspondence.py \
    work_dirs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco/epoch_800.pth \
    ../data/coco/meta/train2017_10nn_instance.json \
    ../data/coco/meta/train2017_10nn_instance_correspondence.json \
    --work-dir work_dirs/selfsup/orl/stage2/correspondence
```

or

```shell
# Run on a cluster managed with Slurm
GPUS_PER_NODE=${GPUS_PER_NODE} GPUS=1 CPUS_PER_TASK=${CPUS_PER_TASK} \
    bash tools/slurm_selective_search_single_gpu.sh ${PARTITION} \
    configs/selfsup/orl/stage2/correspondence.py \
    work_dirs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco/epoch_800.pth \
    ../data/coco/meta/train2017_10nn_instance.json \
    ../data/coco/meta/train2017_10nn_instance_correspondence.json \
    --work-dir work_dirs/selfsup/orl/stage2/correspondence
```

The script and config also only support single-image, single-GPU inference, since different image pairs can produce a different number of inter-RoI pairs, and these cannot be gathered when distributed across multiple GPUs. The script saves the final correspondence JSON file `train2017_10nn_instance_correspondence.json` under `../data/coco/meta/`.
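
To make the retrieval step concrete, the sketch below ranks every cross-image RoI pair by feature similarity and keeps the top-ranked ones. It is an illustration only: it assumes cosine similarity between RoI features, and `cosine` and `top_roi_pairs` are hypothetical names rather than functions from this PR.

```python
import math


def cosine(u, v):
    """Cosine similarity of two feature vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v) if norm_u > 0 and norm_v > 0 else 0.0


def top_roi_pairs(rois_a, rois_b, top_n):
    """Return up to top_n (index_in_a, index_in_b) pairs, most similar first."""
    scored = [(cosine(fa, fb), i, j)
              for i, fa in enumerate(rois_a)
              for j, fb in enumerate(rois_b)]
    # Sort by similarity only; ties keep their original enumeration order.
    scored.sort(key=lambda item: item[0], reverse=True)
    return [(i, j) for _, i, j in scored[:top_n]]


# Toy RoI features for an image and one of its KNN images.
rois_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
rois_b = [[0.0, 2.0], [3.0, 0.0]]
pairs = top_roi_pairs(rois_a, rois_b, top_n=2)
```

The number of returned pairs is bounded by the number of RoIs in each image, so it differs from one image pair to the next. That is the variability that prevents gathering results across GPUs.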

### Stage 3: Object-level pre-training

After obtaining the correspondence file in Stage 2, ORL then performs object-level pre-training:

```shell
# Train with multiple GPUs
bash tools/dist_train.sh \
    configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py \
    ${GPUS} \
    --work-dir work_dirs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco/
```

or

```shell
# Train on a cluster managed with Slurm
GPUS_PER_NODE=${GPUS_PER_NODE} GPUS=${GPUS} CPUS_PER_TASK=${CPUS_PER_TASK} \
    bash tools/slurm_train.sh ${PARTITION} ${JOB_NAME} \
    configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py \
    --work-dir work_dirs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco/
```
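
The model config for this method exposes `global_loss_weight`, `loc_intra_loss_weight`, and `loc_inter_loss_weight` together with a `CosineSimilarityLoss` head. A minimal sketch of how such weights might combine the three terms is given below. It assumes a BYOL-style loss of the form 2 - 2·cos(pred, target) and a plain weighted sum; both are assumptions, and the function names are hypothetical rather than taken from the implementation.

```python
import math


def cosine_similarity_loss(pred, target):
    """Assumed BYOL-style loss: 0 when aligned, 2 when orthogonal, 4 when opposite."""
    dot = sum(p * t for p, t in zip(pred, target))
    norm_p = math.sqrt(sum(p * p for p in pred))
    norm_t = math.sqrt(sum(t * t for t in target))
    return 2.0 - 2.0 * dot / (norm_p * norm_t)


def orl_total_loss(global_loss, intra_loss, inter_loss,
                   global_w=1.0, intra_w=1.0, inter_w=1.0):
    """Assumed combination: weighted sum of image-level, intra-RoI and inter-RoI terms."""
    return global_w * global_loss + intra_w * intra_loss + inter_w * inter_loss


aligned = cosine_similarity_loss([1.0, 0.0], [2.0, 0.0])
orthogonal = cosine_similarity_loss([1.0, 0.0], [0.0, 1.0])
opposite = cosine_similarity_loss([1.0, 0.0], [-1.0, 0.0])
total = orl_total_loss(aligned, orthogonal, opposite)
```

With all three weights left at 1, as in the base model config, each term contributes equally to the total.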

## Models and Benchmarks

Here, we report low-shot image classification results on VOC07 for the model pre-trained on COCO train2017. We report mAP for each value of k across five runs:

| Self-Supervised Config | Best Layer | Weight | k=1 | k=2 | k=4 | k=8 | k=16 | k=32 | k=64 | k=96 |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | --------------------------------------------------------------------------------------------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| [stage3/orl_resnet50_8xb64-coslr-800e_coco](https://github.com/zhaozh10/mmselfsup/blob/2b14f8b06e4ba2596e90f19e4bac0c13757d80f7/configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py) | feature5 | [Pre-trained](https://drive.google.com/drive/folders/1oWzNZpoN_SPc56Gr-l3AlgGSv8jG1izG?usp=sharing) | 42.25 | 51.81 | 63.46 | 72.16 | 77.86 | 81.17 | 83.73 | 84.59 |

## Citation

```bibtex
@inproceedings{xie2021unsupervised,
title={Unsupervised Object-Level Representation Learning from Scene Images},
author={Xie, Jiahao and Zhan, Xiaohang and Liu, Ziwei and Ong, Yew Soon and Loy, Chen Change},
booktitle={NeurIPS},
year={2021}
}
```
26 changes: 26 additions & 0 deletions configs/selfsup/orl/metafile.yml
@@ -0,0 +1,26 @@
Collections:
  - Name: ORL
    Metadata:
      Training Data: COCOtrain2017
      Training Techniques:
        - SGD
      Training Resources: 8x RTX3090 GPUs
      Architecture:
        - ResNet50
    Paper:
      URL: https://arxiv.org/abs/2106.11952
      Title: "Unsupervised Object-Level Representation Learning from Scene Images"
    README: configs/selfsup/orl/README.md
Models:
  - Name: orl_resnet50_8xb64-coslr-800e_coco
    In Collection: ORL
    Metadata:
      Epochs: 800
      Batch Size: 512
    Results:
      - Task: Self-Supervised Low-shot Image Classification
        Dataset: VOC07
        Metrics:
          mAP: 42.25|51.81|63.46|72.16|77.86|81.17|83.73|84.59
    Config: stage3/orl_resnet50_8xb64-coslr-800e_coco.py
    Weights: https://drive.google.com/drive/folders/1oWzNZpoN_SPc56Gr-l3AlgGSv8jG1izG?usp=sharing