diff --git a/README.md b/README.md index 4c5fc47fc33..286a21dce83 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md). - [x] [YOLOX (ArXiv'2021)](configs/yolox/README.md) - [x] [SOLO (ECCV'2020)](configs/solo/README.md) - [x] [QueryInst (ICCV'2021)](configs/queryinst/README.md) +- [x] [TOOD (ICCV'2021)](configs/tood/README.md) Some other methods are also supported in [projects using MMDetection](./docs/en/projects.md). diff --git a/README_zh-CN.md b/README_zh-CN.md index 0d963167c40..3604cbf619c 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -146,6 +146,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope - [x] [YOLOX (ArXiv'2021)](configs/yolox/README.md) - [x] [SOLO (ECCV'2020)](configs/solo/README.md) - [x] [QueryInst (ICCV'2021)](configs/queryinst/README.md) +- [x] [TOOD (ICCV'2021)](configs/tood/README.md) 我们在[基于 MMDetection 的项目](./docs/zh_cn/projects.md)中列举了一些其他的支持的算法。 diff --git a/configs/tood/README.md b/configs/tood/README.md new file mode 100644 index 00000000000..b1522e78565 --- /dev/null +++ b/configs/tood/README.md @@ -0,0 +1,44 @@ +# TOOD: Task-aligned One-stage Object Detection + +## Abstract + + + +One-stage object detection is commonly implemented by optimizing two sub-tasks: object classification and localization, using heads with two parallel branches, which might lead to a certain level of spatial misalignment in predictions between the two tasks. In this work, we propose a Task-aligned One-stage Object Detection (TOOD) that explicitly aligns the two tasks in a learning-based manner. First, we design a novel Task-aligned Head (T-Head) which offers a better balance between learning task-interactive and task-specific features, as well as a greater flexibility to learn the alignment via a task-aligned predictor. Second, we propose Task Alignment Learning (TAL) to explicitly pull closer (or even unify) the optimal anchors for the two tasks during training via a designed sample assignment scheme and a task-aligned loss. Extensive experiments are conducted on MS-COCO, where TOOD achieves a 51.1 AP at single-model single-scale testing. This surpasses the recent one-stage detectors by a large margin, such as ATSS (47.7 AP), GFL (48.2 AP), and PAA (49.0 AP), with fewer parameters and FLOPs. Qualitative results also demonstrate the effectiveness of TOOD for better aligning the tasks of object classification and localization. + + +
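To make the abstract's Task Alignment Learning concrete: TAL ranks each prior by an alignment metric t = s^alpha * u^beta, where s is the classification score for the ground-truth class and u is the IoU between the predicted box and that ground truth. The sketch below is not part of this PR and the function name is hypothetical, but the computation mirrors what the `TaskAlignedAssigner` added later in this diff does.

```python
import torch


def task_alignment_metric(pred_scores, ious, gt_labels, alpha=1, beta=6):
    """Sketch of TAL's ranking metric t = s**alpha * u**beta.

    pred_scores: (num_priors, num_classes) classification probabilities.
    ious: (num_priors, num_gts) IoUs between predicted boxes and gt boxes.
    gt_labels: (num_gts,) class index of each gt box.
    Returns a (num_priors, num_gts) tensor; TAL keeps the top-k priors per gt.
    """
    cls_scores = pred_scores[:, gt_labels]  # score of each prior for each gt's class
    return cls_scores**alpha * ious**beta


# toy example: 6 priors, 80 classes, 2 ground-truth boxes
metric = task_alignment_metric(
    torch.rand(6, 80), torch.rand(6, 2), torch.tensor([3, 17]))
print(metric.shape)  # torch.Size([6, 2])
```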
+ +
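The configs added further down expose TOOD through the standard MMDetection config system. As a quick sanity check, a hypothetical snippet (assuming this PR and mmcv are installed and the working directory is the repository root) can build the detector straight from the new config:

```python
# Hypothetical smoke test, not part of this PR: build the TOOD detector from
# the new config to confirm that the detector, head and losses are registered.
from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/tood/tood_r50_fpn_1x_coco.py')
model = build_detector(cfg.model)
print(type(model).__name__)            # TOOD
print(type(model.bbox_head).__name__)  # TOODHead
```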
+ + + + +## Citation + + + +```latex +@inproceedings{feng2021tood, + title={TOOD: Task-aligned One-stage Object Detection}, + author={Feng, Chengjian and Zhong, Yujie and Gao, Yu and Scott, Matthew R and Huang, Weilin}, + booktitle={ICCV}, + year={2021} +} +``` + +## Results and Models + +| Backbone | Style | Anchor Type | Lr schd | Multi-scale Training| Mem (GB)| Inf time (fps) | box AP | Config | Download | +|:-----------------:|:-------:|:------------:|:-------:|:-------------------:|:-------:|:--------------:|:------:|:------:|:--------:| +| R-50 | pytorch | Anchor-free | 1x | N | 4.1 | | 42.4 | [config](./tood_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425-20e20746.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425.log) | +| R-50 | pytorch | Anchor-based | 1x | N | 4.1 | | 42.4 | [config](./tood_r50_fpn_anchor_based_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105-b776c134.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105.log) | +| R-50 | pytorch | Anchor-free | 2x | Y | 4.1 | | 44.5 | [config](./tood_r50_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231-3b23174c.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231.log) | +| R-101 | pytorch | Anchor-free | 2x | Y | 6.0 | | 46.1 | [config](./tood_r101_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232-a18f53c8.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232.log) | +| R-101-dcnv2 | pytorch | Anchor-free | 2x | Y | 6.2 | | 49.3 | [config](./tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728-4a824142.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728.log) | +| X-101-64x4d | pytorch | Anchor-free | 2x | Y | 10.2 | | 47.6 | [config](./tood_x101_64x4d_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519-a4f36113.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519.log) | +| X-101-64x4d-dcnv2 | pytorch | Anchor-free | 2x | Y | | | | [config](./tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py) | [model]() | [log]() | + +[1] *1x and 2x mean the model is trained for 90K and 180K iterations, respectively.* \ +[2] *All results are obtained with a single model and without any test time data augmentation such as multi-scale, flipping and etc..* \ +[3] *`dcnv2` denotes deformable convolutional networks v2.* \ diff --git a/configs/tood/metafile.yml b/configs/tood/metafile.yml new file mode 100644 
index 00000000000..27a0f8dbfc5 --- /dev/null +++ b/configs/tood/metafile.yml @@ -0,0 +1,95 @@ +Collections: + - Name: TOOD + Metadata: + Training Data: COCO + Training Techniques: + - SGD + Training Resources: 8x V100 GPUs + Architecture: + - TOOD + Paper: + URL: https://arxiv.org/abs/2108.07755 + Title: 'TOOD: Task-aligned One-stage Object Detection' + README: configs/tood/README.md + Code: + URL: https://github.com/open-mmlab/mmdetection/blob/v2.20.0/mmdet/models/detectors/tood.py#L7 + Version: v2.20.0 + +Models: + - Name: tood_r101_fpn_mstrain_2x_coco + In Collection: TOOD + Config: configs/tood/tood_r101_fpn_mstrain_2x_coco.py + Metadata: + Training Memory (GB): 6.0 + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 46.1 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232-a18f53c8.pth + + - Name: tood_x101_64x4d_fpn_mstrain_2x_coco + In Collection: TOOD + Config: configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py + Metadata: + Training Memory (GB): 10.2 + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 47.6 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519-a4f36113.pth + + - Name: tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco + In Collection: TOOD + Config: configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py + Metadata: + Training Memory (GB): 6.2 + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 49.3 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728-4a824142.pth + + - Name: tood_r50_fpn_anchor_based_1x_coco + In Collection: TOOD + Config: configs/tood/tood_r50_fpn_anchor_based_1x_coco.py + Metadata: + Training Memory (GB): 4.1 + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 42.4 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105-b776c134.pth + + - Name: tood_r50_fpn_1x_coco + In Collection: TOOD + Config: configs/tood/tood_r50_fpn_1x_coco.py + Metadata: + Training Memory (GB): 4.1 + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 42.4 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425-20e20746.pth + + - Name: tood_r50_fpn_mstrain_2x_coco + In Collection: TOOD + Config: configs/tood/tood_r50_fpn_mstrain_2x_coco.py + Metadata: + Training Memory (GB): 4.1 + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 44.5 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231-3b23174c.pth diff --git a/configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py b/configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py new file mode 100644 index 00000000000..c7f1bbcbaf1 --- /dev/null +++ b/configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py @@ -0,0 +1,7 @@ +_base_ = './tood_r101_fpn_mstrain_2x_coco.py' + +model = dict( + backbone=dict( + dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False), + stage_with_dcn=(False, True, True, True)), + bbox_head=dict(num_dcn=2)) diff --git 
a/configs/tood/tood_r101_fpn_mstrain_2x_coco.py b/configs/tood/tood_r101_fpn_mstrain_2x_coco.py new file mode 100644 index 00000000000..d9d2c32d8ce --- /dev/null +++ b/configs/tood/tood_r101_fpn_mstrain_2x_coco.py @@ -0,0 +1,7 @@ +_base_ = './tood_r50_fpn_mstrain_2x_coco.py' + +model = dict( + backbone=dict( + depth=101, + init_cfg=dict(type='Pretrained', + checkpoint='torchvision://resnet101'))) diff --git a/configs/tood/tood_r50_fpn_1x_coco.py b/configs/tood/tood_r50_fpn_1x_coco.py new file mode 100644 index 00000000000..35a77a400e1 --- /dev/null +++ b/configs/tood/tood_r50_fpn_1x_coco.py @@ -0,0 +1,74 @@ +_base_ = [ + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +model = dict( + type='TOOD', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='TOODHead', + num_classes=80, + in_channels=256, + stacked_convs=6, + feat_channels=256, + anchor_type='anchor_free', + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + initial_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0)), + train_cfg=dict( + initial_epoch=4, + initial_assigner=dict(type='ATSSAssigner', topk=9), + assigner=dict(type='TaskAlignedAssigner', topk=13), + alpha=1, + beta=6, + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) + +# custom hooks +custom_hooks = [dict(type='SetEpochInfoHook')] diff --git a/configs/tood/tood_r50_fpn_anchor_based_1x_coco.py b/configs/tood/tood_r50_fpn_anchor_based_1x_coco.py new file mode 100644 index 00000000000..c7fbf6aff19 --- /dev/null +++ b/configs/tood/tood_r50_fpn_anchor_based_1x_coco.py @@ -0,0 +1,2 @@ +_base_ = './tood_r50_fpn_1x_coco.py' +model = dict(bbox_head=dict(anchor_type='anchor_based')) diff --git a/configs/tood/tood_r50_fpn_mstrain_2x_coco.py b/configs/tood/tood_r50_fpn_mstrain_2x_coco.py new file mode 100644 index 00000000000..157d13a4a17 --- /dev/null +++ b/configs/tood/tood_r50_fpn_mstrain_2x_coco.py @@ -0,0 +1,22 @@ +_base_ = './tood_r50_fpn_1x_coco.py' +# learning policy +lr_config = dict(step=[16, 22]) +runner = dict(type='EpochBasedRunner', max_epochs=24) +# multi-scale training +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='Resize', + img_scale=[(1333, 480), (1333, 800)], + multiscale_mode='range', + 
keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +data = dict(train=dict(pipeline=train_pipeline)) diff --git a/configs/tood/tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py b/configs/tood/tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py new file mode 100644 index 00000000000..47c92695a92 --- /dev/null +++ b/configs/tood/tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py @@ -0,0 +1,7 @@ +_base_ = './tood_x101_64x4d_fpn_mstrain_2x_coco.py' +model = dict( + backbone=dict( + dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False), + stage_with_dcn=(False, False, True, True), + ), + bbox_head=dict(num_dcn=2)) diff --git a/configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py b/configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py new file mode 100644 index 00000000000..842f320e839 --- /dev/null +++ b/configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py @@ -0,0 +1,16 @@ +_base_ = './tood_r50_fpn_mstrain_2x_coco.py' + +model = dict( + backbone=dict( + type='ResNeXt', + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict( + type='Pretrained', checkpoint='open-mmlab://resnext101_64x4d'))) diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py index 2fef4be8128..a182686491d 100644 --- a/mmdet/core/bbox/assigners/__init__.py +++ b/mmdet/core/bbox/assigners/__init__.py @@ -10,10 +10,12 @@ from .point_assigner import PointAssigner from .region_assigner import RegionAssigner from .sim_ota_assigner import SimOTAAssigner +from .task_aligned_assigner import TaskAlignedAssigner from .uniform_assigner import UniformAssigner __all__ = [ 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner', - 'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner' + 'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner', + 'TaskAlignedAssigner' ] diff --git a/mmdet/core/bbox/assigners/task_aligned_assigner.py b/mmdet/core/bbox/assigners/task_aligned_assigner.py new file mode 100644 index 00000000000..1872de4a780 --- /dev/null +++ b/mmdet/core/bbox/assigners/task_aligned_assigner.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..builder import BBOX_ASSIGNERS +from ..iou_calculators import build_iou_calculator +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + +INF = 100000000 + + +@BBOX_ASSIGNERS.register_module() +class TaskAlignedAssigner(BaseAssigner): + """Task aligned assigner used in the paper: + `TOOD: Task-aligned One-stage Object Detection. + `_. + + Assign a corresponding gt bbox or background to each predicted bbox. + Each bbox will be assigned with `0` or a positive integer + indicating the ground truth index. + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + topk (int): number of bbox selected in each level + iou_calculator (dict): Config dict for iou calculator. 
+ Default: dict(type='BboxOverlaps2D') + """ + + def __init__(self, topk, iou_calculator=dict(type='BboxOverlaps2D')): + assert topk >= 1 + self.topk = topk + self.iou_calculator = build_iou_calculator(iou_calculator) + + def assign(self, + pred_scores, + decode_bboxes, + anchors, + gt_bboxes, + gt_bboxes_ignore=None, + gt_labels=None, + alpha=1, + beta=6): + """Assign gt to bboxes. + + The assignment is done in following steps + + 1. compute alignment metric between all bbox (bbox of all pyramid + levels) and gt + 2. select top-k bbox as candidates for each gt + 3. limit the positive sample's center in gt (because the anchor-free + detector only can predict positive distance) + + + Args: + pred_scores (Tensor): predicted class probability, + shape(n, num_classes) + decode_bboxes (Tensor): predicted bounding boxes, shape(n, 4) + anchors (Tensor): pre-defined anchors, shape(n, 4). + gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`, e.g., crowd boxes in COCO. + gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). + + Returns: + :obj:`TaskAlignedAssignResult`: The assign result. + """ + anchors = anchors[:, :4] + num_gt, num_bboxes = gt_bboxes.size(0), anchors.size(0) + # compute alignment metric between all bbox and gt + overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach() + bbox_scores = pred_scores[:, gt_labels].detach() + # assign 0 by default + assigned_gt_inds = anchors.new_full((num_bboxes, ), + 0, + dtype=torch.long) + assign_metrics = anchors.new_zeros((num_bboxes, )) + + if num_gt == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = anchors.new_zeros((num_bboxes, )) + if num_gt == 0: + # No gt boxes, assign everything to background + assigned_gt_inds[:] = 0 + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = anchors.new_full((num_bboxes, ), + -1, + dtype=torch.long) + assign_result = AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + assign_result.assign_metrics = assign_metrics + return assign_result + + # select top-k bboxes as candidates for each gt + alignment_metrics = bbox_scores**alpha * overlaps**beta + topk = min(self.topk, alignment_metrics.size(0)) + _, candidate_idxs = alignment_metrics.topk(topk, dim=0, largest=True) + candidate_metrics = alignment_metrics[candidate_idxs, + torch.arange(num_gt)] + is_pos = candidate_metrics > 0 + + # limit the positive sample's center in gt + anchors_cx = (anchors[:, 0] + anchors[:, 2]) / 2.0 + anchors_cy = (anchors[:, 1] + anchors[:, 3]) / 2.0 + for gt_idx in range(num_gt): + candidate_idxs[:, gt_idx] += gt_idx * num_bboxes + ep_anchors_cx = anchors_cx.view(1, -1).expand( + num_gt, num_bboxes).contiguous().view(-1) + ep_anchors_cy = anchors_cy.view(1, -1).expand( + num_gt, num_bboxes).contiguous().view(-1) + candidate_idxs = candidate_idxs.view(-1) + + # calculate the left, top, right, bottom distance between positive + # bbox center and gt side + l_ = ep_anchors_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0] + t_ = ep_anchors_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - ep_anchors_cx[candidate_idxs].view(-1, num_gt) + b_ = gt_bboxes[:, 3] - ep_anchors_cy[candidate_idxs].view(-1, num_gt) + is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01 + is_pos = is_pos & is_in_gts + + # if an anchor box is assigned to multiple gts, + # the one with the highest iou will be 
selected. + overlaps_inf = torch.full_like(overlaps, + -INF).t().contiguous().view(-1) + index = candidate_idxs.view(-1)[is_pos.view(-1)] + overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index] + overlaps_inf = overlaps_inf.view(num_gt, -1).t() + + max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1) + assigned_gt_inds[ + max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1 + assign_metrics[max_overlaps != -INF] = alignment_metrics[ + max_overlaps != -INF, argmax_overlaps[max_overlaps != -INF]] + + if gt_labels is not None: + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + else: + assigned_labels = None + assign_result = AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + assign_result.assign_metrics = assign_metrics + return assign_result diff --git a/mmdet/core/hook/__init__.py b/mmdet/core/hook/__init__.py index 31d69a0d0c8..4d4dca6b557 100644 --- a/mmdet/core/hook/__init__.py +++ b/mmdet/core/hook/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .checkloss_hook import CheckInvalidLossHook from .ema import ExpMomentumEMAHook, LinearMomentumEMAHook +from .set_epoch_info_hook import SetEpochInfoHook from .sync_norm_hook import SyncNormHook from .sync_random_size_hook import SyncRandomSizeHook from .yolox_lrupdater_hook import YOLOXLrUpdaterHook @@ -9,5 +10,5 @@ __all__ = [ 'SyncRandomSizeHook', 'YOLOXModeSwitchHook', 'SyncNormHook', 'ExpMomentumEMAHook', 'LinearMomentumEMAHook', 'YOLOXLrUpdaterHook', - 'CheckInvalidLossHook' + 'CheckInvalidLossHook', 'SetEpochInfoHook' ] diff --git a/mmdet/core/hook/set_epoch_info_hook.py b/mmdet/core/hook/set_epoch_info_hook.py new file mode 100644 index 00000000000..c2b134ceb69 --- /dev/null +++ b/mmdet/core/hook/set_epoch_info_hook.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.parallel import is_module_wrapper +from mmcv.runner import HOOKS, Hook + + +@HOOKS.register_module() +class SetEpochInfoHook(Hook): + """Set runner's epoch information to the model.""" + + def before_train_epoch(self, runner): + epoch = runner.epoch + model = runner.model + if is_module_wrapper(model): + model = model.module + model.set_epoch(epoch) diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index e78444aca59..81d6ec2f74d 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -31,6 +31,7 @@ from .sabl_retina_head import SABLRetinaHead from .solo_head import DecoupledSOLOHead, DecoupledSOLOLightHead, SOLOHead from .ssd_head import SSDHead +from .tood_head import TOODHead from .vfnet_head import VFNetHead from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead from .yolo_head import YOLOV3Head @@ -48,5 +49,5 @@ 'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead', 'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead', - 'DecoupledSOLOLightHead', 'LADHead' + 'DecoupledSOLOLightHead', 'LADHead', 'TOODHead' ] diff --git a/mmdet/models/dense_heads/tood_head.py b/mmdet/models/dense_heads/tood_head.py new file mode 100644 index 00000000000..8a2a70462ef --- /dev/null +++ b/mmdet/models/dense_heads/tood_head.py @@ -0,0 +1,769 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
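Before the head implementation that follows, here is a small usage sketch of the `TaskAlignedAssigner` introduced above, with hypothetical random inputs and assuming mmdet with this PR is installed; it follows the `assign()` signature added in this diff.

```python
import torch
from mmdet.core.bbox.assigners import TaskAlignedAssigner

assigner = TaskAlignedAssigner(topk=13)

num_priors, num_classes = 100, 80
pred_scores = torch.rand(num_priors, num_classes)   # sigmoid-activated scores
xy = torch.rand(num_priors, 2) * 100
decode_bboxes = torch.cat([xy, xy + torch.rand(num_priors, 2) * 50], dim=1)
anchors = torch.cat([xy, xy + 8.0], dim=1)          # dummy (x1, y1, x2, y2) priors
gt_bboxes = torch.tensor([[10., 10., 60., 60.], [40., 5., 90., 45.]])
gt_labels = torch.tensor([3, 17])

result = assigner.assign(pred_scores, decode_bboxes, anchors, gt_bboxes,
                         gt_labels=gt_labels, alpha=1, beta=6)
print(result.gt_inds.shape)         # (100,): 0 = background, i > 0 = 1-based gt index
print(result.assign_metrics.shape)  # (100,): alignment metric of each prior
```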
+import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init +from mmcv.ops import deform_conv2d +from mmcv.runner import force_fp32 + +from mmdet.core import (anchor_inside_flags, build_assigner, distance2bbox, + images_to_levels, multi_apply, reduce_mean, unmap) +from mmdet.core.utils import filter_scores_and_topk +from ..builder import HEADS, build_loss +from .atss_head import ATSSHead + + +class TaskDecomposition(nn.Module): + """Task decomposition module in task-aligned predictor of TOOD. + + Args: + feat_channels (int): Number of feature channels in TOOD head. + stacked_convs (int): Number of conv layers in TOOD head. + la_down_rate (int): Downsample rate of layer attention. + conv_cfg (dict): Config dict for convolution layer. + norm_cfg (dict): Config dict for normalization layer. + """ + + def __init__(self, + feat_channels, + stacked_convs, + la_down_rate=8, + conv_cfg=None, + norm_cfg=None): + super(TaskDecomposition, self).__init__() + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.in_channels = self.feat_channels * self.stacked_convs + self.norm_cfg = norm_cfg + self.layer_attention = nn.Sequential( + nn.Conv2d(self.in_channels, self.in_channels // la_down_rate, 1), + nn.ReLU(inplace=True), + nn.Conv2d( + self.in_channels // la_down_rate, + self.stacked_convs, + 1, + padding=0), nn.Sigmoid()) + + self.reduction_conv = ConvModule( + self.in_channels, + self.feat_channels, + 1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=norm_cfg is None) + + def init_weights(self): + for m in self.layer_attention.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + normal_init(self.reduction_conv.conv, std=0.01) + + def forward(self, feat, avg_feat=None): + b, c, h, w = feat.shape + if avg_feat is None: + avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) + weight = self.layer_attention(avg_feat) + + # here we first compute the product between layer attention weight and + # conv weight, and then compute the convolution between new conv weight + # and feature map, in order to save memory and FLOPs. + conv_weight = weight.reshape( + b, 1, self.stacked_convs, + 1) * self.reduction_conv.conv.weight.reshape( + 1, self.feat_channels, self.stacked_convs, self.feat_channels) + conv_weight = conv_weight.reshape(b, self.feat_channels, + self.in_channels) + feat = feat.reshape(b, self.in_channels, h * w) + feat = torch.bmm(conv_weight, feat).reshape(b, self.feat_channels, h, + w) + if self.norm_cfg is not None: + feat = self.reduction_conv.norm(feat) + feat = self.reduction_conv.activate(feat) + + return feat + + +@HEADS.register_module() +class TOODHead(ATSSHead): + """TOODHead used in `TOOD: Task-aligned One-stage Object Detection. + + `_. + + TOOD uses Task-aligned head (T-head) and is optimized by Task Alignment + Learning (TAL). + + Args: + num_dcn (int): Number of deformable convolution in the head. + Default: 0. + anchor_type (str): If set to `anchor_free`, the head will use centers + to regress bboxes. If set to `anchor_based`, the head will + regress bboxes based on anchors. Default: `anchor_free`. + initial_loss_cls (dict): Config of initial loss. 
+ + Example: + >>> self = TOODHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_score, bbox_pred = self.forward(feats) + >>> assert len(cls_score) == len(self.scales) + """ + + def __init__(self, + num_classes, + in_channels, + num_dcn=0, + anchor_type='anchor_free', + initial_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + **kwargs): + assert anchor_type in ['anchor_free', 'anchor_based'] + self.num_dcn = num_dcn + self.anchor_type = anchor_type + self.epoch = 0 # which would be update in SetEpochInfoHook! + super(TOODHead, self).__init__(num_classes, in_channels, **kwargs) + + if self.train_cfg: + self.initial_epoch = self.train_cfg.initial_epoch + self.initial_assigner = build_assigner( + self.train_cfg.initial_assigner) + self.initial_loss_cls = build_loss(initial_loss_cls) + self.assigner = self.initial_assigner + self.alignment_assigner = build_assigner(self.train_cfg.assigner) + self.alpha = self.train_cfg.alpha + self.beta = self.train_cfg.beta + + def _init_layers(self): + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.inter_convs = nn.ModuleList() + for i in range(self.stacked_convs): + if i < self.num_dcn: + conv_cfg = dict(type='DCNv2', deform_groups=4) + else: + conv_cfg = self.conv_cfg + chn = self.in_channels if i == 0 else self.feat_channels + self.inter_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg)) + + self.cls_decomp = TaskDecomposition(self.feat_channels, + self.stacked_convs, + self.stacked_convs * 8, + self.conv_cfg, self.norm_cfg) + self.reg_decomp = TaskDecomposition(self.feat_channels, + self.stacked_convs, + self.stacked_convs * 8, + self.conv_cfg, self.norm_cfg) + + self.tood_cls = nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + self.tood_reg = nn.Conv2d( + self.feat_channels, self.num_base_priors * 4, 3, padding=1) + + self.cls_prob_module = nn.Sequential( + nn.Conv2d(self.feat_channels * self.stacked_convs, + self.feat_channels // 4, 1), nn.ReLU(inplace=True), + nn.Conv2d(self.feat_channels // 4, 1, 3, padding=1)) + self.reg_offset_module = nn.Sequential( + nn.Conv2d(self.feat_channels * self.stacked_convs, + self.feat_channels // 4, 1), nn.ReLU(inplace=True), + nn.Conv2d(self.feat_channels // 4, 4 * 2, 3, padding=1)) + + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.prior_generator.strides]) + + def init_weights(self): + """Initialize weights of the head.""" + bias_cls = bias_init_with_prob(0.01) + for m in self.inter_convs: + normal_init(m.conv, std=0.01) + for m in self.cls_prob_module: + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.01) + for m in self.reg_offset_module: + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + normal_init(self.cls_prob_module[-1], std=0.01, bias=bias_cls) + + self.cls_decomp.init_weights() + self.reg_decomp.init_weights() + + normal_init(self.tood_cls, std=0.01, bias=bias_cls) + normal_init(self.tood_reg, std=0.01) + + def forward(self, feats): + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * num_classes. 
+ bbox_preds (list[Tensor]): Decoded box for all scale levels, + each is a 4D-tensor, the channels number is + num_anchors * 4. In [tl_x, tl_y, br_x, br_y] format. + """ + cls_scores = [] + bbox_preds = [] + for idx, (x, scale, stride) in enumerate( + zip(feats, self.scales, self.prior_generator.strides)): + b, c, h, w = x.shape + anchor = self.prior_generator.single_level_grid_priors( + (h, w), idx, device=x.device) + anchor = torch.cat([anchor for _ in range(b)]) + # extract task interactive features + inter_feats = [] + for inter_conv in self.inter_convs: + x = inter_conv(x) + inter_feats.append(x) + feat = torch.cat(inter_feats, 1) + + # task decomposition + avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) + cls_feat = self.cls_decomp(feat, avg_feat) + reg_feat = self.reg_decomp(feat, avg_feat) + + # cls prediction and alignment + cls_logits = self.tood_cls(cls_feat) + cls_prob = self.cls_prob_module(feat) + cls_score = (cls_logits.sigmoid() * cls_prob.sigmoid()).sqrt() + + # reg prediction and alignment + if self.anchor_type == 'anchor_free': + reg_dist = scale(self.tood_reg(reg_feat).exp()).float() + reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4) + reg_bbox = distance2bbox( + self.anchor_center(anchor) / stride[0], + reg_dist).reshape(b, h, w, 4).permute(0, 3, 1, + 2) # (b, c, h, w) + elif self.anchor_type == 'anchor_based': + reg_dist = scale(self.tood_reg(reg_feat)).float() + reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4) + reg_bbox = self.bbox_coder.decode(anchor, reg_dist).reshape( + b, h, w, 4).permute(0, 3, 1, 2) / stride[0] + else: + raise NotImplementedError( + f'Unknown anchor type: {self.anchor_type}.' + f'Please use `anchor_free` or `anchor_based`.') + reg_offset = self.reg_offset_module(feat) + bbox_pred = self.deform_sampling(reg_bbox.contiguous(), + reg_offset.contiguous()) + cls_scores.append(cls_score) + bbox_preds.append(bbox_pred) + return tuple(cls_scores), tuple(bbox_preds) + + def deform_sampling(self, feat, offset): + """Sampling the feature x according to offset. + + Args: + feat (Tensor): Feature + offset (Tensor): Spatial offset for for feature sampliing + """ + # it is an equivalent implementation of bilinear interpolation + b, c, h, w = feat.shape + weight = feat.new_ones(c, 1, 1, 1) + y = deform_conv2d(feat, offset, weight, 1, 0, 1, c, c) + return y + + def anchor_center(self, anchors): + """Get anchor centers from anchors. + + Args: + anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Anchor centers with shape (N, 2), "xy" format. + """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + return torch.stack([anchors_cx, anchors_cy], dim=-1) + + def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, + bbox_targets, alignment_metrics, stride): + """Compute loss of a single scale level. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Decoded bboxes for each scale + level with shape (N, num_anchors * 4, H, W). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors). + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). 
+ alignment_metrics (Tensor): Alignment metrics with shape + (N, num_total_anchors). + stride (tuple[int]): Downsample stride of the feature map. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert stride[0] == stride[1], 'h stride is not equal to w stride!' + anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, 1).reshape( + -1, self.cls_out_channels).contiguous() + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + alignment_metrics = alignment_metrics.reshape(-1) + label_weights = label_weights.reshape(-1) + targets = labels if self.epoch < self.initial_epoch else ( + labels, alignment_metrics) + cls_loss_func = self.initial_loss_cls \ + if self.epoch < self.initial_epoch else self.loss_cls + + loss_cls = cls_loss_func( + cls_score, targets, label_weights, avg_factor=1.0) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + + pos_decode_bbox_pred = pos_bbox_pred + pos_decode_bbox_targets = pos_bbox_targets / stride[0] + + # regression loss + pos_bbox_weight = self.centerness_target( + pos_anchors, pos_bbox_targets + ) if self.epoch < self.initial_epoch else alignment_metrics[ + pos_inds] + + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=pos_bbox_weight, + avg_factor=1.0) + else: + loss_bbox = bbox_pred.sum() * 0 + pos_bbox_weight = bbox_targets.new_tensor(0.) + + return loss_cls, loss_bbox, alignment_metrics.sum( + ), pos_bbox_weight.sum() + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Decoded box for each scale + level with shape (N, num_anchors * 4, H, W) in + [tl_x, tl_y, br_x, br_y] format. + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor] | None): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_imgs = len(img_metas) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_metas, device=device) + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + + flatten_cls_scores = torch.cat([ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_score in cls_scores + ], 1) + flatten_bbox_preds = torch.cat([ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) * stride[0] + for bbox_pred, stride in zip(bbox_preds, + self.prior_generator.strides) + ], 1) + + cls_reg_targets = self.get_targets( + flatten_cls_scores, + flatten_bbox_preds, + anchor_list, + valid_flag_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=label_channels) + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + alignment_metrics_list) = cls_reg_targets + + losses_cls, losses_bbox,\ + cls_avg_factors, bbox_avg_factors = multi_apply( + self.loss_single, + anchor_list, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + alignment_metrics_list, + self.prior_generator.strides) + + cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() + losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) + + bbox_avg_factor = reduce_mean( + sum(bbox_avg_factors)).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + + def _get_bboxes_single(self, + cls_score_list, + bbox_pred_list, + score_factor_list, + mlvl_priors, + img_meta, + cfg, + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bboxes, 5], where the first 4 columns are bounding \ + box positions (tl_x, tl_y, br_x, br_y) and the 5-th \ + column are scores between 0 and 1. 
+ - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bboxes]. + """ + + cfg = self.test_cfg if cfg is None else cfg + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_labels = [] + for cls_score, bbox_pred, priors, stride in zip( + cls_score_list, bbox_pred_list, mlvl_priors, + self.prior_generator.strides): + + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) * stride[0] + scores = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. + results = filter_scores_and_topk( + scores, cfg.score_thr, nms_pre, + dict(bbox_pred=bbox_pred, priors=priors)) + scores, labels, keep_idxs, filtered_results = results + + bboxes = filtered_results['bbox_pred'] + priors = filtered_results['priors'] + + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes, + img_meta['scale_factor'], cfg, rescale, + with_nms, None, **kwargs) + + def get_targets(self, + cls_scores, + bbox_preds, + anchor_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + unmap_outputs=True): + """Compute regression and classification targets for anchors in + multiple images. + + Args: + cls_scores (Tensor): Classification predictions of images, + a 3D-Tensor with shape [num_imgs, num_priors, num_classes]. + bbox_preds (Tensor): Decoded bboxes predictions of one image, + a 3D-Tensor with shape [num_imgs, num_priors, 4] in [tl_x, + tl_y, br_x, br_y] format. + anchor_list (list[list[Tensor]]): Multi level anchors of each + image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, 4). + valid_flag_list (list[list[Tensor]]): Multi level valid flags of + each image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, ) + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be + ignored. + gt_labels_list (list[Tensor]): Ground truth labels of each box. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: a tuple containing learning targets. + + - anchors_list (list[list[Tensor]]): Anchors of each level. + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each + level. + - bbox_targets_list (list[Tensor]): BBox targets of each level. + - norm_alignment_metrics_list (list[Tensor]): Normalized + alignment metrics of each level. 
+ """ + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + # anchor_list: list(b * [-1, 4]) + + if self.epoch < self.initial_epoch: + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply( + super()._get_target_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + all_assign_metrics = [ + weight[..., 0] for weight in all_bbox_weights + ] + else: + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_assign_metrics) = multi_apply( + self._get_target_single, + cls_scores, + bbox_preds, + anchor_list, + valid_flag_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + norm_alignment_metrics_list = images_to_levels(all_assign_metrics, + num_level_anchors) + + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, norm_alignment_metrics_list) + + def _get_target_single(self, + cls_scores, + bbox_preds, + flat_anchors, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + img_meta, + label_channels=1, + unmap_outputs=True): + """Compute regression, classification targets for anchors in a single + image. + + Args: + cls_scores (list(Tensor)): Box scores for each image. + bbox_preds (list(Tensor)): Box energies / deltas for each image. + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors ,4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + img_meta (dict): Meta info of the image. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + anchors (Tensor): All anchors in the image with shape (N, 4). + labels (Tensor): Labels of all anchors in the image with shape + (N,). 
+ label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + norm_alignment_metrics (Tensor): Normalized alignment metrics + of all priors in the image with shape (N,). + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg.allowed_border) + if not inside_flags.any(): + return (None, ) * 7 + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + assign_result = self.alignment_assigner.assign( + cls_scores[inside_flags, :], bbox_preds[inside_flags, :], anchors, + gt_bboxes, gt_bboxes_ignore, gt_labels, self.alpha, self.beta) + assign_ious = assign_result.max_overlaps + assign_metrics = assign_result.assign_metrics + + sampling_result = self.sampler.sample(assign_result, anchors, + gt_bboxes) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + norm_alignment_metrics = anchors.new_zeros( + num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + # point-based + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + + if gt_labels is None: + # Only rpn gives gt_labels as None + # Foreground is the first class since v2.5.0 + labels[pos_inds] = 0 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + class_assigned_gt_inds = torch.unique( + sampling_result.pos_assigned_gt_inds) + for gt_inds in class_assigned_gt_inds: + gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds == + gt_inds] + pos_alignment_metrics = assign_metrics[gt_class_inds] + pos_ious = assign_ious[gt_class_inds] + pos_norm_alignment_metrics = pos_alignment_metrics / ( + pos_alignment_metrics.max() + 10e-8) * pos_ious.max() + norm_alignment_metrics[gt_class_inds] = pos_norm_alignment_metrics + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + anchors = unmap(anchors, num_total_anchors, inside_flags) + labels = unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + norm_alignment_metrics = unmap(norm_alignment_metrics, + num_total_anchors, inside_flags) + return (anchors, labels, label_weights, bbox_targets, + norm_alignment_metrics) diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 637434a60b5..456b8d424fb 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -32,6 +32,7 @@ from .single_stage import SingleStageDetector from .solo import SOLO from .sparse_rcnn import SparseRCNN +from .tood import TOOD from .trident_faster_rcnn import TridentFasterRCNN from .two_stage import TwoStageDetector from .vfnet import VFNet @@ -48,5 +49,5 @@ 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO', 'DeformableDETR', 'AutoAssign', 
'YOLOF', 'CenterNet', 'YOLOX', - 'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD' + 'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD' ] diff --git a/mmdet/models/detectors/tood.py b/mmdet/models/detectors/tood.py new file mode 100644 index 00000000000..7dd18c3c96a --- /dev/null +++ b/mmdet/models/detectors/tood.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..builder import DETECTORS +from .single_stage import SingleStageDetector + + +@DETECTORS.register_module() +class TOOD(SingleStageDetector): + r"""Implementation of `TOOD: Task-aligned One-stage Object Detection. + `_.""" + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(TOOD, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained, init_cfg) + + def set_epoch(self, epoch): + self.bbox_head.epoch = epoch diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py index 92909117ced..6c20fddd56f 100644 --- a/mmdet/models/losses/focal_loss.py +++ b/mmdet/models/losses/focal_loss.py @@ -57,6 +57,59 @@ def py_sigmoid_focal_loss(pred, return loss +def py_focal_loss_with_prob(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + """PyTorch version of `Focal Loss `_. + Different from `py_sigmoid_focal_loss`, this function accepts probability + as input. + + Args: + pred (torch.Tensor): The prediction probability with shape (N, C), + C is the number of classes. + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + num_classes = pred.size(1) + target = F.one_hot(target, num_classes=num_classes + 1) + target = target[:, :num_classes] + + target = target.type_as(pred) + pt = (1 - pred) * target + pred * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * pt.pow(gamma) + loss = F.binary_cross_entropy( + pred, target, reduction='none') * focal_weight + if weight is not None: + if weight.shape != loss.shape: + if weight.size(0) == loss.size(0): + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.view(loss.size(0), -1) + assert weight.ndim == loss.ndim + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + def sigmoid_focal_loss(pred, target, weight=None, @@ -111,7 +164,8 @@ def __init__(self, gamma=2.0, alpha=0.25, reduction='mean', - loss_weight=1.0): + loss_weight=1.0, + activated=False): """`Focal Loss `_ Args: @@ -125,6 +179,10 @@ def __init__(self, a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". loss_weight (float, optional): Weight of loss. Defaults to 1.0. + activated (bool, optional): Whether the input is activated. 
+ If True, it means the input has been activated and can be + treated as probabilities. Else, it should be treated as logits. + Defaults to False. """ super(FocalLoss, self).__init__() assert use_sigmoid is True, 'Only sigmoid focal loss supported now.' @@ -133,6 +191,7 @@ def __init__(self, self.alpha = alpha self.reduction = reduction self.loss_weight = loss_weight + self.activated = activated def forward(self, pred, @@ -160,13 +219,16 @@ def forward(self, reduction = ( reduction_override if reduction_override else self.reduction) if self.use_sigmoid: - if torch.cuda.is_available() and pred.is_cuda: - calculate_loss_func = sigmoid_focal_loss + if self.activated: + calculate_loss_func = py_focal_loss_with_prob else: - num_classes = pred.size(1) - target = F.one_hot(target, num_classes=num_classes + 1) - target = target[:, :num_classes] - calculate_loss_func = py_sigmoid_focal_loss + if torch.cuda.is_available() and pred.is_cuda: + calculate_loss_func = sigmoid_focal_loss + else: + num_classes = pred.size(1) + target = F.one_hot(target, num_classes=num_classes + 1) + target = target[:, :num_classes] + calculate_loss_func = py_sigmoid_focal_loss loss_cls = self.loss_weight * calculate_loss_func( pred, diff --git a/mmdet/models/losses/gfocal_loss.py b/mmdet/models/losses/gfocal_loss.py index a7a1b765f4f..0e8d26373f8 100644 --- a/mmdet/models/losses/gfocal_loss.py +++ b/mmdet/models/losses/gfocal_loss.py @@ -52,6 +52,52 @@ def quality_focal_loss(pred, target, beta=2.0): return loss +@weighted_loss +def quality_focal_loss_with_prob(pred, target, beta=2.0): + r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning + Qualified and Distributed Bounding Boxes for Dense Object Detection + `_. + Different from `quality_focal_loss`, this function accepts probability + as input. + + Args: + pred (torch.Tensor): Predicted joint representation of classification + and quality (IoU) estimation with shape (N, C), C is the number of + classes. + target (tuple([torch.Tensor])): Target category label with shape (N,) + and target quality label with shape (N,). + beta (float): The beta parameter for calculating the modulating factor. + Defaults to 2.0. + + Returns: + torch.Tensor: Loss tensor with shape (N,). + """ + assert len(target) == 2, """target for QFL must be a tuple of two elements, + including category label and quality label, respectively""" + # label denotes the category id, score denotes the quality score + label, score = target + + # negatives are supervised by 0 quality score + pred_sigmoid = pred + scale_factor = pred_sigmoid + zerolabel = scale_factor.new_zeros(pred.shape) + loss = F.binary_cross_entropy( + pred, zerolabel, reduction='none') * scale_factor.pow(beta) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = pred.size(1) + pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1) + pos_label = label[pos].long() + # positives are supervised by bbox quality (IoU) score + scale_factor = score[pos] - pred_sigmoid[pos, pos_label] + loss[pos, pos_label] = F.binary_cross_entropy( + pred[pos, pos_label], score[pos], + reduction='none') * scale_factor.abs().pow(beta) + + loss = loss.sum(dim=1, keepdim=False) + return loss + + @mmcv.jit(derivate=True, coderize=True) @weighted_loss def distribution_focal_loss(pred, label): @@ -91,19 +137,25 @@ class QualityFocalLoss(nn.Module): Defaults to 2.0. reduction (str): Options are "none", "mean" and "sum". loss_weight (float): Loss weight of current loss. 
+ activated (bool, optional): Whether the input is activated. + If True, it means the input has been activated and can be + treated as probabilities. Else, it should be treated as logits. + Defaults to False. """ def __init__(self, use_sigmoid=True, beta=2.0, reduction='mean', - loss_weight=1.0): + loss_weight=1.0, + activated=False): super(QualityFocalLoss, self).__init__() assert use_sigmoid is True, 'Only sigmoid in QFL supported now.' self.use_sigmoid = use_sigmoid self.beta = beta self.reduction = reduction self.loss_weight = loss_weight + self.activated = activated def forward(self, pred, @@ -131,7 +183,11 @@ def forward(self, reduction = ( reduction_override if reduction_override else self.reduction) if self.use_sigmoid: - loss_cls = self.loss_weight * quality_focal_loss( + if self.activated: + calculate_loss_func = quality_focal_loss_with_prob + else: + calculate_loss_func = quality_focal_loss + loss_cls = self.loss_weight * calculate_loss_func( pred, target, weight, diff --git a/model-index.yml b/model-index.yml index a2dcea4d3a5..900b9385242 100644 --- a/model-index.yml +++ b/model-index.yml @@ -57,6 +57,7 @@ Import: - configs/ssd/metafile.yml - configs/swin/metafile.yml - configs/tridentnet/metafile.yml + - configs/tood/metafile.yml - configs/vfnet/metafile.yml - configs/yolact/metafile.yml - configs/yolo/metafile.yml diff --git a/setup.cfg b/setup.cfg index b0171d35b2f..18adf687165 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,4 +15,4 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true [codespell] skip = *.ipynb quiet-level = 3 -ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids +ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood diff --git a/tests/test_models/test_dense_heads/test_tood_head.py b/tests/test_models/test_dense_heads/test_tood_head.py new file mode 100644 index 00000000000..2174cc0e4db --- /dev/null +++ b/tests/test_models/test_dense_heads/test_tood_head.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import mmcv +import torch + +from mmdet.models.dense_heads import TOODHead + + +def test_tood_head_loss(): + """Tests tood head loss when truth is empty and non-empty.""" + + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + train_cfg = mmcv.Config( + dict( + initial_epoch=4, + initial_assigner=dict(type='ATSSAssigner', topk=9), + assigner=dict(type='TaskAlignedAssigner', topk=13), + alpha=1, + beta=6, + allowed_border=-1, + pos_weight=-1, + debug=False)) + test_cfg = mmcv.Config( + dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + # since Focal Loss is not supported on CPU + self = TOODHead( + num_classes=80, + in_channels=1, + stacked_convs=6, + feat_channels=256, + anchor_type='anchor_free', + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + initial_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + train_cfg=train_cfg, + test_cfg=test_cfg) + self.init_weights() + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [8, 16, 32, 64, 128] + ] + cls_scores, bbox_preds = self(feat) + + # test initial assigner and losses + self.epoch = 0 + # Test that empty ground truth encourages the network to predict background + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + empty_cls_loss = empty_gt_losses['loss_cls'] + empty_box_loss = empty_gt_losses['loss_bbox'] + assert sum(empty_cls_loss).item() > 0, 'cls loss should be non-zero' + assert sum(empty_box_loss).item() == 0, ( + 'there should be no box loss when there are no true boxes') + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + onegt_cls_loss = one_gt_losses['loss_cls'] + onegt_box_loss = one_gt_losses['loss_bbox'] + assert sum(onegt_cls_loss).item() > 0, 'cls loss should be non-zero' + assert sum(onegt_box_loss).item() > 0, 'box loss should be non-zero' + + # test task alignment assigner and losses + self.epoch = 10 + # Test that empty ground truth encourages the network to predict background + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss.
+ empty_cls_loss = empty_gt_losses['loss_cls'] + empty_box_loss = empty_gt_losses['loss_bbox'] + assert sum(empty_cls_loss).item() > 0, 'cls loss should be non-zero' + assert sum(empty_box_loss).item() == 0, ( + 'there should be no box loss when there are no true boxes') + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + onegt_cls_loss = one_gt_losses['loss_cls'] + onegt_box_loss = one_gt_losses['loss_bbox'] + assert sum(onegt_cls_loss).item() > 0, 'cls loss should be non-zero' + assert sum(onegt_box_loss).item() > 0, 'box loss should be non-zero' diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py index 3a42d35dcc8..ca82aeda127 100644 --- a/tests/test_utils/test_assigner.py +++ b/tests/test_utils/test_assigner.py @@ -5,12 +5,13 @@ pytest tests/test_utils/test_assigner.py xdoctest tests/test_utils/test_assigner.py zero """ +import pytest import torch from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner, CenterRegionAssigner, HungarianAssigner, MaxIoUAssigner, PointAssigner, - UniformAssigner) + TaskAlignedAssigner, UniformAssigner) def test_max_iou_assigner(): @@ -496,3 +497,45 @@ def test_uniform_assigner_with_empty_boxes(): # Test without gt_labels assign_result = self.assign(pred_bbox, anchor, gt_bboxes, gt_labels=None) assert len(assign_result.gt_inds) == 0 + + +def test_task_aligned_assigner(): + with pytest.raises(AssertionError): + TaskAlignedAssigner(topk=0) + + self = TaskAlignedAssigner(topk=13) + pred_score = torch.FloatTensor([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], + [0.4, 0.5]]) + pred_bbox = torch.FloatTensor([ + [1, 1, 12, 8], + [4, 4, 20, 20], + [1, 5, 15, 15], + [30, 5, 32, 42], + ]) + anchor = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_labels = torch.LongTensor([0, 1]) + assign_result = self.assign( + pred_score, + pred_bbox, + anchor, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels) + assert len(assign_result.gt_inds) == 4 + assert len(assign_result.labels) == 4 + + # test empty gt + gt_bboxes = torch.empty(0, 4) + gt_labels = torch.empty(0, 2) + assign_result = self.assign( + pred_score, pred_bbox, anchor, gt_bboxes=gt_bboxes) + expected_gt_inds = torch.LongTensor([0, 0, 0, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) diff --git a/tests/test_utils/test_hook.py b/tests/test_utils/test_hook.py index afd176788f0..43fab670a05 100644 --- a/tests/test_utils/test_hook.py +++ b/tests/test_utils/test_hook.py @@ -323,3 +323,32 @@ def train_step(self, x, optimizer, **kwargs): with pytest.raises(AssertionError): runner.run([loader], [('train', 1)]) shutil.rmtree(runner.work_dir) + + +def test_set_epoch_info_hook(): + """Test SetEpochInfoHook.""" + + class DemoModel(nn.Module): + + def __init__(self): + super().__init__() + self.epoch = 0 + self.linear = nn.Linear(2, 1) + + def forward(self, x): + return self.linear(x) + + def train_step(self, x, optimizer, **kwargs): + return dict(loss=self(x)) + + def set_epoch(self, epoch): + self.epoch = epoch + + loader = DataLoader(torch.ones((5, 2))) + runner = _build_demo_runner(max_epochs=3) + + demo_model = DemoModel() + runner.model = demo_model + 
runner.register_hook_from_cfg(dict(type='SetEpochInfoHook')) + runner.run([loader], [('train', 1)]) + assert demo_model.epoch == 2
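
TOOD's task-aligned head applies the sigmoid itself and feeds probabilities to its classification losses, which is what the new `activated=True` switch in `FocalLoss` is for. The snippet below is a minimal consistency sketch, not part of the patch: with the default `alpha`/`gamma`, the probability-based helper should reproduce the logit-based one. Note that `py_sigmoid_focal_loss` expects one-hot targets, while the new `py_focal_loss_with_prob` one-hot encodes integer labels internally.

```python
# Consistency sketch for the two focal-loss helpers; assumes default
# alpha=0.25, gamma=2.0 and treats the last class index as background.
import torch
import torch.nn.functional as F
from mmdet.models.losses.focal_loss import (py_focal_loss_with_prob,
                                            py_sigmoid_focal_loss)

num_classes = 80
logits = torch.randn(8, num_classes)
labels = torch.randint(0, num_classes + 1, (8,))  # num_classes == background

# logit-based helper: needs one-hot targets with the background column dropped
onehot = F.one_hot(labels, num_classes + 1)[:, :num_classes].float()
loss_from_logits = py_sigmoid_focal_loss(logits, onehot)

# probability-based helper: takes integer labels and already-activated scores
loss_from_probs = py_focal_loss_with_prob(logits.sigmoid(), labels)

assert torch.allclose(loss_from_logits, loss_from_probs, atol=1e-5)
```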
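The same pattern holds for `QualityFocalLoss`: `quality_focal_loss_with_prob` mirrors `quality_focal_loss` but skips the internal sigmoid, and its target is the `(category label, quality score)` tuple described by the assert in the new function. A quick sanity sketch under the same assumptions:

```python
# Sketch: logit-based and probability-based QFL should agree once the logits
# are passed through sigmoid. Targets are (category label, IoU quality score).
import torch
from mmdet.models.losses.gfocal_loss import (quality_focal_loss,
                                             quality_focal_loss_with_prob)

num_classes = 80
logits = torch.randn(8, num_classes)
labels = torch.randint(0, num_classes + 1, (8,))  # num_classes == background
ious = torch.rand(8)                              # bbox quality (IoU) targets

loss_from_logits = quality_focal_loss(logits, (labels, ious))
loss_from_probs = quality_focal_loss_with_prob(logits.sigmoid(), (labels, ious))
assert torch.allclose(loss_from_logits, loss_from_probs, atol=1e-5)
```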
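`TaskAlignedAssigner` implements the assignment side of Task Alignment Learning. The sketch below only illustrates the alignment metric it ranks candidates by, `t = s**alpha * u**beta` from the TOOD paper, with `alpha=1, beta=6` taken from the train_cfg above; it is not the assigner's internal code path, and the toy tensors are borrowed from the new assigner test.

```python
# Illustration of the task-alignment metric: classification score of the
# ground-truth class weighted against the IoU between prediction and gt box.
import torch
from mmdet.core.bbox import bbox_overlaps

alpha, beta = 1, 6
pred_score = torch.FloatTensor([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]])
pred_bbox = torch.FloatTensor([[1, 1, 12, 8], [4, 4, 20, 20],
                               [1, 5, 15, 15], [30, 5, 32, 42]])
gt_bboxes = torch.FloatTensor([[0, 0, 10, 9], [0, 10, 10, 19]])
gt_labels = torch.LongTensor([0, 1])

ious = bbox_overlaps(pred_bbox, gt_bboxes)      # shape (num_priors, num_gt)
scores = pred_score[:, gt_labels]               # score of each gt's class
alignment = scores.pow(alpha) * ious.pow(beta)  # larger = better aligned
print(alignment.argmax(dim=0))                  # most aligned prior per gt
```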
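Finally, `SetEpochInfoHook` is what makes the two-stage assignment schedule work end to end: the hook hands the current epoch to the model, `TOOD.set_epoch` forwards it to the bbox head, and the head uses `initial_assigner` while the epoch is below `initial_epoch` and the task-aligned assigner afterwards. A hypothetical config fragment follows; only the hook's registry name and the train_cfg keys are confirmed by the diff, and wiring the hook through `custom_hooks` is an assumption.

```python
# Hypothetical config fragment: hand the epoch to the model each epoch so
# Task Alignment Learning can switch assigners after `initial_epoch`.
custom_hooks = [dict(type='SetEpochInfoHook')]

train_cfg = dict(
    initial_epoch=4,                                     # warm-up period
    initial_assigner=dict(type='ATSSAssigner', topk=9),  # used while epoch < 4
    assigner=dict(type='TaskAlignedAssigner', topk=13),  # used afterwards
    alpha=1,
    beta=6,
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
```

Judging by the hook test, the hook passes the zero-based epoch index, so a three-epoch run leaves the model with `epoch == 2`.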