[Feature] Support TOOD: Task-aligned One-stage Object Detection (ICCV 2021 Oral) (open-mmlab#6746)

* [Feature] Support TOOD.

* update

* use assign result

* use assign result

* clean assigner

* add config

* add tood head unit test and fix device bug

* test assigner and fix empty gt error

* test hook

* add anchor-based cfg and readme

* update readme

* resolve comments

* resolve comment

* add metafile

* fix model index

* copyright

* resolve comments

* resolve comments
RangiLyu authored Dec 24, 2021
1 parent d3fcce5 commit ec42990
Showing 26 changed files with 1,573 additions and 15 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -149,6 +149,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
- [x] [YOLOX (ArXiv'2021)](configs/yolox/README.md)
- [x] [SOLO (ECCV'2020)](configs/solo/README.md)
- [x] [QueryInst (ICCV'2021)](configs/queryinst/README.md)
- [x] [TOOD (ICCV'2021)](configs/tood/README.md)
</details>

Some other methods are also supported in [projects using MMDetection](./docs/en/projects.md).
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -146,6 +146,7 @@ MMDetection is an open-source object detection toolbox based on PyTorch. It is [Ope
- [x] [YOLOX (ArXiv'2021)](configs/yolox/README.md)
- [x] [SOLO (ECCV'2020)](configs/solo/README.md)
- [x] [QueryInst (ICCV'2021)](configs/queryinst/README.md)
- [x] [TOOD (ICCV'2021)](configs/tood/README.md)
</details>

Some other supported methods are listed in [projects based on MMDetection](./docs/zh_cn/projects.md).
44 changes: 44 additions & 0 deletions configs/tood/README.md
@@ -0,0 +1,44 @@
# TOOD: Task-aligned One-stage Object Detection

## Abstract

<!-- [ABSTRACT] -->

One-stage object detection is commonly implemented by optimizing two sub-tasks, object classification and localization, using heads with two parallel branches, which might lead to a certain level of spatial misalignment between the predictions of the two tasks. In this work, we propose Task-aligned One-stage Object Detection (TOOD), which explicitly aligns the two tasks in a learning-based manner. First, we design a novel Task-aligned Head (T-Head) which offers a better balance between learning task-interactive and task-specific features, as well as greater flexibility to learn the alignment via a task-aligned predictor. Second, we propose Task Alignment Learning (TAL) to explicitly pull closer (or even unify) the optimal anchors for the two tasks during training via a designed sample assignment scheme and a task-aligned loss. Extensive experiments are conducted on MS-COCO, where TOOD achieves 51.1 AP at single-model single-scale testing, surpassing recent one-stage detectors such as ATSS (47.7 AP), GFL (48.2 AP), and PAA (49.0 AP) by a large margin, with fewer parameters and FLOPs. Qualitative results also demonstrate the effectiveness of TOOD for better aligning the tasks of object classification and localization.

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/12907710/145400075-e08191f5-8afa-4335-9b3b-27926fc9a26e.png"/>
</div>
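To make the alignment idea concrete, here is a minimal sketch (not code from this PR) of the task alignment metric that TAL is built around, `t = s^α · u^β`, where `s` is the predicted classification score for the ground-truth class and `u` is the IoU between the predicted and ground-truth boxes; the defaults `alpha=1, beta=6` match `train_cfg` in the configs below.

```python
import torch

def task_alignment_metric(cls_scores, ious, alpha=1.0, beta=6.0):
    """Sketch of TAL's alignment metric t = s**alpha * u**beta.

    cls_scores: (num_anchors,) predicted probabilities for the gt class.
    ious: (num_anchors,) IoU of each predicted box with the gt box.
    A high t requires both a confident classification and an accurate box,
    which is what pulls the two tasks' optimal anchors together.
    """
    return cls_scores.pow(alpha) * ious.pow(beta)
```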

<!-- [PAPER_TITLE: TOOD: Task-aligned One-stage Object Detection] -->
<!-- [PAPER_URL: https://arxiv.org/abs/2108.07755] -->

## Citation

<!-- [ALGORITHM] -->

```latex
@inproceedings{feng2021tood,
  title={TOOD: Task-aligned One-stage Object Detection},
  author={Feng, Chengjian and Zhong, Yujie and Gao, Yu and Scott, Matthew R and Huang, Weilin},
  booktitle={ICCV},
  year={2021}
}
```

## Results and Models

| Backbone | Style | Anchor Type | Lr schd | Multi-scale Training | Mem (GB) | Inf time (fps) | box AP | Config | Download |
|:-----------------:|:-------:|:------------:|:-------:|:-------------------:|:-------:|:--------------:|:------:|:------:|:--------:|
| R-50 | pytorch | Anchor-free | 1x | N | 4.1 | | 42.4 | [config](./tood_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425-20e20746.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425.log) |
| R-50 | pytorch | Anchor-based | 1x | N | 4.1 | | 42.4 | [config](./tood_r50_fpn_anchor_based_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105-b776c134.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105.log) |
| R-50 | pytorch | Anchor-free | 2x | Y | 4.1 | | 44.5 | [config](./tood_r50_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231-3b23174c.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231.log) |
| R-101 | pytorch | Anchor-free | 2x | Y | 6.0 | | 46.1 | [config](./tood_r101_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232-a18f53c8.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232.log) |
| R-101-dcnv2 | pytorch | Anchor-free | 2x | Y | 6.2 | | 49.3 | [config](./tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728-4a824142.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728.log) |
| X-101-64x4d | pytorch | Anchor-free | 2x | Y | 10.2 | | 47.6 | [config](./tood_x101_64x4d_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519-a4f36113.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519.log) |
| X-101-64x4d-dcnv2 | pytorch | Anchor-free | 2x | Y | | | | [config](./tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py) | [model]() &#124; [log]() |

[1] *1x and 2x mean the model is trained for 90K and 180K iterations, respectively.* \
[2] *All results are obtained with a single model and without any test-time augmentation such as multi-scale testing or flipping.* \
[3] *`dcnv2` denotes deformable convolutional networks v2.*
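As a usage sketch (not part of this diff), any config/checkpoint pair from the table can be loaded with MMDetection's standard high-level inference API; the image path below is a placeholder.

```python
from mmdet.apis import init_detector, inference_detector

# Illustrative paths; substitute any config/checkpoint pair from the table.
config_file = 'configs/tood/tood_r50_fpn_1x_coco.py'
checkpoint_file = 'tood_r50_fpn_1x_coco_20211210_103425-20e20746.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
# Returns a list over classes of [x1, y1, x2, y2, score] arrays.
result = inference_detector(model, 'demo.jpg')
```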
95 changes: 95 additions & 0 deletions configs/tood/metafile.yml
@@ -0,0 +1,95 @@
Collections:
  - Name: TOOD
    Metadata:
      Training Data: COCO
      Training Techniques:
        - SGD
      Training Resources: 8x V100 GPUs
      Architecture:
        - TOOD
    Paper:
      URL: https://arxiv.org/abs/2108.07755
      Title: 'TOOD: Task-aligned One-stage Object Detection'
    README: configs/tood/README.md
    Code:
      URL: https://github.com/open-mmlab/mmdetection/blob/v2.20.0/mmdet/models/detectors/tood.py#L7
      Version: v2.20.0

Models:
  - Name: tood_r101_fpn_mstrain_2x_coco
    In Collection: TOOD
    Config: configs/tood/tood_r101_fpn_mstrain_2x_coco.py
    Metadata:
      Training Memory (GB): 6.0
      Epochs: 24
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 46.1
    Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232-a18f53c8.pth

  - Name: tood_x101_64x4d_fpn_mstrain_2x_coco
    In Collection: TOOD
    Config: configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py
    Metadata:
      Training Memory (GB): 10.2
      Epochs: 24
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 47.6
    Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519-a4f36113.pth

  - Name: tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco
    In Collection: TOOD
    Config: configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py
    Metadata:
      Training Memory (GB): 6.2
      Epochs: 24
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 49.3
    Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728-4a824142.pth

  - Name: tood_r50_fpn_anchor_based_1x_coco
    In Collection: TOOD
    Config: configs/tood/tood_r50_fpn_anchor_based_1x_coco.py
    Metadata:
      Training Memory (GB): 4.1
      Epochs: 12
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 42.4
    Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105-b776c134.pth

  - Name: tood_r50_fpn_1x_coco
    In Collection: TOOD
    Config: configs/tood/tood_r50_fpn_1x_coco.py
    Metadata:
      Training Memory (GB): 4.1
      Epochs: 12
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 42.4
    Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425-20e20746.pth

  - Name: tood_r50_fpn_mstrain_2x_coco
    In Collection: TOOD
    Config: configs/tood/tood_r50_fpn_mstrain_2x_coco.py
    Metadata:
      Training Memory (GB): 4.1
      Epochs: 24
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 44.5
    Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231-3b23174c.pth
7 changes: 7 additions & 0 deletions configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py
@@ -0,0 +1,7 @@
_base_ = './tood_r101_fpn_mstrain_2x_coco.py'

model = dict(
    backbone=dict(
        dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    bbox_head=dict(num_dcn=2))
7 changes: 7 additions & 0 deletions configs/tood/tood_r101_fpn_mstrain_2x_coco.py
@@ -0,0 +1,7 @@
_base_ = './tood_r50_fpn_mstrain_2x_coco.py'

model = dict(
    backbone=dict(
        depth=101,
        init_cfg=dict(type='Pretrained',
                      checkpoint='torchvision://resnet101')))
74 changes: 74 additions & 0 deletions configs/tood/tood_r50_fpn_1x_coco.py
@@ -0,0 +1,74 @@
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    type='TOOD',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',
        num_outs=5),
    bbox_head=dict(
        type='TOODHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=6,
        feat_channels=256,
        anchor_type='anchor_free',
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[0.1, 0.1, 0.2, 0.2]),
        initial_loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            activated=True,  # use probability instead of logit as input
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_cls=dict(
            type='QualityFocalLoss',
            use_sigmoid=True,
            activated=True,  # use probability instead of logit as input
            beta=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=2.0)),
    train_cfg=dict(
        initial_epoch=4,
        initial_assigner=dict(type='ATSSAssigner', topk=9),
        assigner=dict(type='TaskAlignedAssigner', topk=13),
        alpha=1,
        beta=6,
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

# custom hooks
custom_hooks = [dict(type='SetEpochInfoHook')]
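The interplay of `initial_epoch`, the two assigners, and `SetEpochInfoHook` deserves a note: the hook is assumed to push the current epoch into the head, which warms up with ATSS assignment before switching to task-aligned assignment. A minimal sketch of that switch, with hypothetical method and attribute names rather than the actual TOODHead code:

```python
# Hypothetical sketch of the epoch-based assigner switch configured above.
# SetEpochInfoHook is assumed to set `self.epoch` on the head every epoch.
def get_current_assigner(self):
    if self.epoch < self.initial_epoch:   # epochs 0-3: warm-up phase
        return self.initial_assigner      # ATSSAssigner(topk=9)
    return self.assigner                  # TaskAlignedAssigner(topk=13)
```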
2 changes: 2 additions & 0 deletions configs/tood/tood_r50_fpn_anchor_based_1x_coco.py
@@ -0,0 +1,2 @@
_base_ = './tood_r50_fpn_1x_coco.py'
model = dict(bbox_head=dict(anchor_type='anchor_based'))
22 changes: 22 additions & 0 deletions configs/tood/tood_r50_fpn_mstrain_2x_coco.py
@@ -0,0 +1,22 @@
_base_ = './tood_r50_fpn_1x_coco.py'
# learning policy
lr_config = dict(step=[16, 22])
runner = dict(type='EpochBasedRunner', max_epochs=24)
# multi-scale training
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Resize',
        img_scale=[(1333, 480), (1333, 800)],
        multiscale_mode='range',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
data = dict(train=dict(pipeline=train_pipeline))
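For reference, `multiscale_mode='range'` with these two `img_scale` endpoints means each training image is resized to a freshly sampled scale; a rough sketch of the sampling rule, assuming the short edge is drawn uniformly between the endpoints while the long edge stays capped at 1333 with the aspect ratio preserved:

```python
import random

# Rough sketch: with multiscale_mode='range' and
# img_scale=[(1333, 480), (1333, 800)], the short edge is drawn uniformly
# from [480, 800] per image; the long edge is capped at 1333.
short_edge = random.randint(480, 800)
sampled_scale = (1333, short_edge)
```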
7 changes: 7 additions & 0 deletions configs/tood/tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py
@@ -0,0 +1,7 @@
_base_ = './tood_x101_64x4d_fpn_mstrain_2x_coco.py'
model = dict(
    backbone=dict(
        dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
    ),
    bbox_head=dict(num_dcn=2))
16 changes: 16 additions & 0 deletions configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py
@@ -0,0 +1,16 @@
_base_ = './tood_r50_fpn_mstrain_2x_coco.py'

model = dict(
    backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=64,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(
            type='Pretrained', checkpoint='open-mmlab://resnext101_64x4d')))
4 changes: 3 additions & 1 deletion mmdet/core/bbox/assigners/__init__.py
@@ -10,10 +10,12 @@
from .point_assigner import PointAssigner
from .region_assigner import RegionAssigner
from .sim_ota_assigner import SimOTAAssigner
from .task_aligned_assigner import TaskAlignedAssigner
from .uniform_assigner import UniformAssigner

__all__ = [
    'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
    'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
    'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
    'TaskAlignedAssigner'
]
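For intuition, a rough sketch (with assumed tensor shapes, not the registered assigner's actual code) of the selection rule `TaskAlignedAssigner` is built around: rank anchors per ground truth by the alignment metric and keep the top k, matching `topk=13` in the configs above.

```python
import torch

def select_candidates(alignment_metrics, topk=13):
    """alignment_metrics: (num_gts, num_anchors) tensor of t = s**alpha * u**beta.

    For each gt, keep the topk highest-metric anchors as positive candidates.
    """
    _, candidate_idxs = alignment_metrics.topk(topk, dim=1, largest=True)
    return candidate_idxs
```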
