diff --git a/README.md b/README.md index dde36eeae..eed08502b 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,7 @@ A summary can be found in the [Model Zoo](docs/en/model_zoo.md) page. - [x] [PSC](configs/psc/README.md) (CVPR'2023) - [x] [RTMDet](configs/rotated_rtmdet/README.md) (arXiv) - [x] [H2RBox-v2](configs/h2rbox_v2/README.md) (NeurIPS'2023) +- [x] [Point2RBox](configs/point2rbox/README.md) (CVPR'2024) diff --git a/configs/h2rbox_v2/README.md b/configs/h2rbox_v2/README.md index dfa93d3c3..47cef9fa3 100644 --- a/configs/h2rbox_v2/README.md +++ b/configs/h2rbox_v2/README.md @@ -44,9 +44,9 @@ HRSC ``` @inproceedings{yu2023h2rboxv2, - title={H2RBox-v2: Incorporating Symmetry for Boosting Horizontal Box Supervised Oriented Object Detection}, - author={Yi Yu and Xue Yang and Qingyun Li and Yue Zhou and and Feipeng Da and Junchi Yan}, - year={2023}, - booktitle={Advances in Neural Information Processing Systems} +title={H2RBox-v2: Incorporating Symmetry for Boosting Horizontal Box Supervised Oriented Object Detection}, +author={Yi Yu and Xue Yang and Qingyun Li and Yue Zhou and Feipeng Da and Junchi Yan}, +year={2023}, +booktitle={Advances in Neural Information Processing Systems} } ``` diff --git a/configs/point2rbox/README.md b/configs/point2rbox/README.md new file mode 100644 index 000000000..d6a1338b3 --- /dev/null +++ b/configs/point2rbox/README.md @@ -0,0 +1,48 @@ +# Point2RBox + +> [Point2RBox: Combine Knowledge from Synthetic Visual Patterns for End-to-end Oriented Object Detection with Single Point Supervision](https://arxiv.org/pdf/2311.14758) + + + +## Abstract + +
+ +
+ +With the rapidly increasing demand for oriented object detection (OOD), recent research involving weakly-supervised detectors for learning rotated box (RBox) from the horizontal box (HBox) has attracted more and more attention. In this paper, we explore a more challenging yet label-efficient setting, namely single point-supervised OOD, and present our approach called Point2RBox. Specifically, we propose to leverage two principles: 1) Synthetic pattern knowledge combination: By sampling around each labelled point on the image, we transfer the object feature to synthetic visual patterns with the known bounding box to provide the knowledge for box regression. 2) Transform self-supervision: With a transformed input image (e.g. scaled/rotated), the output RBoxes are trained to follow the same transformation so that the network can perceive the relative size/rotation between objects. The detector is further enhanced by a few devised techniques to cope with peripheral issues, e.g. the anchor/layer assignment as the size of the object is not available in our point supervision setting. To our best knowledge, Point2RBox is the first end-to-end solution for point-supervised OOD. In particular, our method uses a lightweight paradigm, yet it achieves a competitive performance among point-supervised alternatives, 41.05%/27.62%/80.01% on DOTA/DIOR/HRSC datasets. + +## Basic patterns + +Extract [basic_patterns.zip](https://github.com/open-mmlab/mmrotate/files/13816461/basic_patterns.zip) to data folder. The path can also be modified in config files. + +## Results and models + +DOTA1.0 + +| Backbone | AP50 | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download | +| :----------------------: | :---: | :-----: | :------: | :------------: | :-: | :--------: | :-------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ResNet50 (1024,1024,200) | 41.87 | 1x | 16.12 | 111.7 | - | 2 | [point2rbox-yolof-dota](./point2rbox-yolof-dota.py) | [model](https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-dota/point2rbox-yolof-dota-c94da82d.pth) \| [log](https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-dota/point2rbox-yolof-dota.json) | + +DIOR + +| Backbone | AP50 | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download | +| :----------------: | :---: | :-----: | :------: | :------------: | :-: | :--------: | :-------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ResNet50 (800,800) | 27.34 | 1x | 10.38 | 127.3 | - | 2 | [point2rbox-yolof-dior](./point2rbox-yolof-dior.py) | [model](https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-dior/point2rbox-yolof-dior-f4f724df.pth) \| [log](https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-dior/point2rbox-yolof-dior.json) | + +HRSC + +| Backbone | AP50 | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download | +| :----------------: | :---: | :-----: | :------: | :------------: | :-: | :--------: | :-------------------------------------------------: | 
:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ResNet50 (800,800) | 79.40 | 6x | 9.60 | 136.9 | - | 2 | [point2rbox-yolof-hrsc](./point2rbox-yolof-hrsc.py) | [model](https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-hrsc/point2rbox-yolof-hrsc-9d096323.pth) \| [log](https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-hrsc/point2rbox-yolof-hrsc.json) | + +## Citation + +``` +@inproceedings{yu2024point2rbox, +title={Point2RBox: Combine Knowledge from Synthetic Visual Patterns for End-to-end Oriented Object Detection with Single Point Supervision}, +author={Yi Yu and Xue Yang and Qingyun Li and Feipeng Da and Jifeng Dai and Yu Qiao and Junchi Yan}, +year={2024}, +booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition} +} +``` diff --git a/configs/point2rbox/metafile.yml b/configs/point2rbox/metafile.yml new file mode 100755 index 000000000..7bbc094b4 --- /dev/null +++ b/configs/point2rbox/metafile.yml @@ -0,0 +1,50 @@ +Collections: +- Name: point2rbox + Metadata: + Training Data: DOTAv1.0 + Training Techniques: + - AdamW + Training Resources: 1x GeForce RTX 4090 + Architecture: + - ResNet + Paper: + URL: https://arxiv.org/pdf/2311.14758.pdf + Title: 'Point2RBox: Combine Knowledge from Synthetic Visual Patterns for End-to-end Oriented Object Detection with Single Point Supervision' + README: configs/point2rbox/README.md + +Models: + - Name: point2rbox-yolof-dota + In Collection: point2rbox + Config: configs/point2rbox/point2rbox-yolof-dota.py + Metadata: + Training Data: DOTAv1.0 + Results: + - Task: Oriented Object Detection + Dataset: DOTAv1.0 + Metrics: + mAP: 41.87 + Weights: https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-dota/point2rbox-yolof-dota-c94da82d.pth + + - Name: point2rbox-yolof-dior + In Collection: point2rbox + Config: configs/point2rbox/point2rbox-yolof-dior.py + Metadata: + Training Data: DIOR + Results: + - Task: Oriented Object Detection + Dataset: DIOR + Metrics: + mAP: 27.34 + Weights: https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-dior/point2rbox-yolof-dior-f4f724df.pth + + - Name: point2rbox-yolof-hrsc + In Collection: point2rbox + Config: configs/point2rbox/point2rbox-yolof-hrsc.py + Metadata: + Training Data: HRSC + Results: + - Task: Oriented Object Detection + Dataset: HRSC + Metrics: + mAP: 79.40 + Weights: https://download.openmmlab.com/mmrotate/v1.0/point2rbox/point2rbox-yolof-hrsc/point2rbox-yolof-hrsc-9d096323.pth diff --git a/configs/point2rbox/point2rbox-yolof-dior.py b/configs/point2rbox/point2rbox-yolof-dior.py new file mode 100755 index 000000000..27e412f96 --- /dev/null +++ b/configs/point2rbox/point2rbox-yolof-dior.py @@ -0,0 +1,156 @@ +_base_ = [ + '../_base_/datasets/dior.py', '../_base_/schedules/schedule_1x.py', + '../_base_/default_runtime.py' +] +model = dict( + type='Point2RBoxYOLOF', + crop_size=(800, 800), + prob_rot=0.95 * 0.7, + prob_flp=0.05 * 0.7, + sca_fact=1.0, + sca_range=(0.5, 1.5), + basic_pattern='data/basic_patterns/dior', + dense_cls=[], + use_setrc=False, + use_setsk=True, + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[103.530, 116.280, 123.675], + std=[1.0, 1.0, 1.0], + bgr_to_rgb=False, + pad_size_divisor=32), + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + strides=(1, 2, 2, 1), # 
DC5 + dilations=(1, 1, 1, 2), + out_indices=(3, ), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='caffe', + init_cfg=dict( + type='Pretrained', + checkpoint='open-mmlab://detectron/resnet50_caffe')), + neck=dict( + type='mmdet.DilatedEncoder', + in_channels=2048, + out_channels=512, + block_mid_channels=128, + num_residual_blocks=4, + block_dilations=[2, 4, 6, 8]), + bbox_head=dict( + type='Point2RBoxYOLOFHead', + num_classes=20, + in_channels=512, + reg_decoded_bbox=True, + num_cls_convs=4, + num_reg_convs=8, + use_objectness=False, + agnostic_cls=[2, 5, 9, 14, 15], + square_cls=[], + anchor_generator=dict( + type='mmdet.AnchorGenerator', + ratios=[1.0], + scales=[8, 8, 8, 8, 8, 8, 8], + strides=[16]), + bbox_coder=dict( + type='mmdet.DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1., 1., 1., 1.], + add_ctr_clamp=True, + ctr_clamp=16), + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0), + loss_angle=dict(type='mmdet.L1Loss', loss_weight=0.3), + loss_scale_ss=dict(type='mmdet.GIoULoss', loss_weight=0.02)), + # training and testing settings + train_cfg=dict( + assigner=dict( + type='Point2RBoxAssigner', + pos_ignore_thr=0.15, + neg_ignore_thr=0.7, + match_times=4), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000)) + +# optimizer +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type='AdamW', + lr=0.00005, + betas=(0.9, 0.999), + weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0., custom_keys={'backbone': dict(lr_mult=1. / 3)})) + +train_pipeline = [ + dict(type='mmdet.LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), + dict(type='mmdet.FixShapeResize', width=800, height=800, keep_ratio=True), + dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), + dict(type='RBox2Point'), + dict( + type='mmdet.RandomFlip', + prob=0.75, + direction=['horizontal', 'vertical', 'diagonal']), + dict(type='RandomRotate', prob=1, angle_range=180), + dict(type='mmdet.RandomShift', prob=0.5, max_shift_px=16), + dict(type='mmdet.PackDetInputs') +] + +dataset_type = 'DIORDataset' +data_root = 'data/dior/' +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=None, + dataset=dict( + type='ConcatDataset', + ignore_keys=['DATASET_TYPE'], + datasets=[ + dict( + type=dataset_type, + data_root=data_root, + ann_file='ImageSets/Main/train.txt', + data_prefix=dict(img_path='JPEGImages-trainval'), + filter_cfg=dict(filter_empty_gt=True), + pipeline=train_pipeline), + dict( + type=dataset_type, + data_root=data_root, + ann_file='ImageSets/Main/val.txt', + data_prefix=dict(img_path='JPEGImages-trainval'), + filter_cfg=dict(filter_empty_gt=True), + pipeline=train_pipeline, + backend_args=_base_.backend_args) + ])) + +train_cfg = dict(type='EpochBasedTrainLoop', val_interval=12) + +val_dataloader = dict(batch_size=4, num_workers=4) + +val_evaluator = dict(type='DOTAMetric', metric='mAP', iou_thrs=[0.25, 0.5]) + +# default_hooks = dict(logger=dict(type='LoggerHook', interval=30)) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. 
+# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/configs/point2rbox/point2rbox-yolof-dota.py b/configs/point2rbox/point2rbox-yolof-dota.py new file mode 100755 index 000000000..43ab984c8 --- /dev/null +++ b/configs/point2rbox/point2rbox-yolof-dota.py @@ -0,0 +1,129 @@ +_base_ = [ + '../_base_/datasets/dota.py', '../_base_/schedules/schedule_1x.py', + '../_base_/default_runtime.py' +] +model = dict( + type='Point2RBoxYOLOF', + crop_size=(1024, 1024), + prob_rot=0.95 * 0.7, + prob_flp=0.05 * 0.7, + sca_fact=0.4, + sca_range=(0.5, 1.5), + basic_pattern='data/basic_patterns/dota', + dense_cls=[4, 5, 6, 9], + use_setrc=False, + use_setsk=True, + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[103.530, 116.280, 123.675], + std=[1.0, 1.0, 1.0], + bgr_to_rgb=False, + pad_size_divisor=32), + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + strides=(1, 2, 2, 1), # DC5 + dilations=(1, 1, 1, 2), + out_indices=(3, ), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='caffe', + init_cfg=dict( + type='Pretrained', + checkpoint='open-mmlab://detectron/resnet50_caffe')), + neck=dict( + type='mmdet.DilatedEncoder', + in_channels=2048, + out_channels=512, + block_mid_channels=128, + num_residual_blocks=4, + block_dilations=[2, 4, 6, 8]), + bbox_head=dict( + type='Point2RBoxYOLOFHead', + num_classes=15, + in_channels=512, + reg_decoded_bbox=True, + num_cls_convs=4, + num_reg_convs=8, + use_objectness=False, + agnostic_cls=[1, 9, 11], + square_cls=[0], + anchor_generator=dict( + type='mmdet.AnchorGenerator', + ratios=[1.0], + scales=[4, 4, 4, 4, 4], + strides=[16]), + bbox_coder=dict( + type='mmdet.DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1., 1., 1., 1.], + add_ctr_clamp=True, + ctr_clamp=16), + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=2.0), + loss_angle=dict(type='mmdet.L1Loss', loss_weight=0.6), + loss_scale_ss=dict(type='mmdet.GIoULoss', loss_weight=0.04)), + # training and testing settings + train_cfg=dict( + assigner=dict( + type='Point2RBoxAssigner', + pos_ignore_thr=0.15, + neg_ignore_thr=0.7, + match_times=2), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000)) + +# optimizer +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type='AdamW', + lr=0.00005, + betas=(0.9, 0.999), + weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0., custom_keys={'backbone': dict(lr_mult=1. 
/ 3)})) + +train_pipeline = [ + dict(type='mmdet.LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), + dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), + dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), + dict(type='RBox2Point'), + dict( + type='mmdet.RandomFlip', + prob=0.75, + direction=['horizontal', 'vertical', 'diagonal']), + dict(type='mmdet.RandomShift', prob=0.5, max_shift_px=16), + dict(type='mmdet.PackDetInputs') +] + +train_cfg = dict(type='EpochBasedTrainLoop', val_interval=12) + +train_dataloader = dict( + batch_size=4, num_workers=4, dataset=dict(pipeline=train_pipeline)) + +val_dataloader = dict(batch_size=4, num_workers=4) + +val_evaluator = dict(type='DOTAMetric', metric='mAP', iou_thrs=[0.25, 0.5]) + +# default_hooks = dict(logger=dict(type='LoggerHook', interval=30)) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/configs/point2rbox/point2rbox-yolof-hrsc.py b/configs/point2rbox/point2rbox-yolof-hrsc.py new file mode 100755 index 000000000..8dbb7f085 --- /dev/null +++ b/configs/point2rbox/point2rbox-yolof-hrsc.py @@ -0,0 +1,130 @@ +_base_ = [ + '../_base_/datasets/hrsc.py', '../_base_/schedules/schedule_6x.py', + '../_base_/default_runtime.py' +] +model = dict( + type='Point2RBoxYOLOF', + crop_size=(800, 800), + prob_rot=0.95 * 0.7, + prob_flp=0.05 * 0.7, + sca_fact=1.0, + sca_range=(0.5, 1.5), + basic_pattern='data/basic_patterns/hrsc', + dense_cls=[], + use_setrc=True, + use_setsk=True, + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[103.530, 116.280, 123.675], + std=[1.0, 1.0, 1.0], + bgr_to_rgb=False, + pad_size_divisor=32), + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + strides=(1, 2, 2, 1), # DC5 + dilations=(1, 1, 1, 2), + out_indices=(3, ), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='caffe', + init_cfg=dict( + type='Pretrained', + checkpoint='open-mmlab://detectron/resnet50_caffe')), + neck=dict( + type='mmdet.DilatedEncoder', + in_channels=2048, + out_channels=512, + block_mid_channels=128, + num_residual_blocks=4, + block_dilations=[2, 4, 6, 8]), + bbox_head=dict( + type='Point2RBoxYOLOFHead', + num_classes=1, + in_channels=512, + reg_decoded_bbox=True, + num_cls_convs=4, + num_reg_convs=8, + use_objectness=False, + agnostic_cls=[], + square_cls=[], + anchor_generator=dict( + type='mmdet.AnchorGenerator', + ratios=[1.0], + scales=[8], + strides=[16]), + bbox_coder=dict( + type='mmdet.DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1., 1., 1., 1.], + add_ctr_clamp=True, + ctr_clamp=16), + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0), + loss_angle=dict(type='mmdet.L1Loss', loss_weight=0.3), + loss_scale_ss=dict(type='mmdet.GIoULoss', loss_weight=0.02)), + # training and testing settings + train_cfg=dict( + assigner=dict( + type='Point2RBoxAssigner', + pos_ignore_thr=0.15, + neg_ignore_thr=0.7, + match_times=4), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000)) + +# optimizer +optim_wrapper = dict( + optimizer=dict( + 
_delete_=True, + type='AdamW', + lr=0.00005, + betas=(0.9, 0.999), + weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0., custom_keys={'backbone': dict(lr_mult=1. / 3)})) + +train_pipeline = [ + dict(type='mmdet.LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), + dict(type='mmdet.FixShapeResize', width=800, height=800, keep_ratio=True), + dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), + dict(type='RBox2Point'), + dict( + type='mmdet.RandomFlip', + prob=0.75, + direction=['horizontal', 'vertical', 'diagonal']), + dict(type='RandomRotate', prob=1, angle_range=180), + dict(type='mmdet.RandomShift', prob=0.5, max_shift_px=16), + dict(type='mmdet.PackDetInputs') +] + +train_cfg = dict(type='EpochBasedTrainLoop', val_interval=12) + +train_dataloader = dict( + batch_size=4, num_workers=4, dataset=dict(pipeline=train_pipeline)) + +val_dataloader = dict(batch_size=4, num_workers=4) + +val_evaluator = dict(type='DOTAMetric', metric='mAP', iou_thrs=[0.25, 0.5]) + +# default_hooks = dict(logger=dict(type='LoggerHook', interval=30)) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/mmrotate/datasets/transforms/__init__.py b/mmrotate/datasets/transforms/__init__.py index ee0801fcd..f5dc98770 100644 --- a/mmrotate/datasets/transforms/__init__.py +++ b/mmrotate/datasets/transforms/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .loading import LoadPatchFromNDArray from .transforms import (ConvertBoxType, ConvertMask2BoxType, - RandomChoiceRotate, RandomRotate, Rotate) + RandomChoiceRotate, RandomRotate, RBox2Point, Rotate) __all__ = [ 'LoadPatchFromNDArray', 'Rotate', 'RandomRotate', 'RandomChoiceRotate', - 'ConvertBoxType', 'ConvertMask2BoxType' + 'ConvertBoxType', 'RBox2Point', 'ConvertMask2BoxType' ] diff --git a/mmrotate/datasets/transforms/transforms.py b/mmrotate/datasets/transforms/transforms.py index 141c42569..7ea77d833 100644 --- a/mmrotate/datasets/transforms/transforms.py +++ b/mmrotate/datasets/transforms/transforms.py @@ -43,6 +43,23 @@ def __repr__(self): return repr_str +@TRANSFORMS.register_module() +class RBox2Point(BaseTransform): + """Convert RBoxes to Single Center Points.""" + + def __init__(self) -> None: + pass + + def transform(self, results: dict) -> dict: + """The transform function.""" + + results['gt_bboxes'].tensor[:, 2] = 0.1 + results['gt_bboxes'].tensor[:, 3] = 0.1 + results['gt_bboxes'].tensor[:, 4] = 0 + + return results + + @TRANSFORMS.register_module() class Rotate(BaseTransform): """Rotate the images, bboxes, masks and segmentation map by a certain diff --git a/mmrotate/models/dense_heads/__init__.py b/mmrotate/models/dense_heads/__init__.py index e6119ccda..2702bfec3 100644 --- a/mmrotate/models/dense_heads/__init__.py +++ b/mmrotate/models/dense_heads/__init__.py @@ -5,6 +5,7 @@ from .h2rbox_v2_head import H2RBoxV2Head from .oriented_reppoints_head import OrientedRepPointsHead from .oriented_rpn_head import OrientedRPNHead +from .point2rbox_yolof_head import Point2RBoxYOLOFHead from .r3_head import R3Head, R3RefineHead from .rotated_atss_head import RotatedATSSHead from .rotated_fcos_head import RotatedFCOSHead @@ -19,5 +20,5 @@ 'SAMRepPointsHead', 'AngleBranchRetinaHead', 'RotatedATSSHead', 'RotatedFCOSHead', 'OrientedRepPointsHead', 'R3Head', 'R3RefineHead', 
'S2AHead', 'S2ARefineHead', 'CFAHead', 'H2RBoxHead', 'H2RBoxV2Head', - 'RotatedRTMDetHead', 'RotatedRTMDetSepBNHead' + 'RotatedRTMDetHead', 'RotatedRTMDetSepBNHead', 'Point2RBoxYOLOFHead' ] diff --git a/mmrotate/models/dense_heads/point2rbox_yolof_head.py b/mmrotate/models/dense_heads/point2rbox_yolof_head.py new file mode 100755 index 000000000..419008865 --- /dev/null +++ b/mmrotate/models/dense_heads/point2rbox_yolof_head.py @@ -0,0 +1,943 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, is_norm +from mmdet.models.dense_heads import YOLOFHead +from mmdet.models.task_modules.prior_generators import anchor_inside_flags +from mmdet.models.utils import (filter_scores_and_topk, levels_to_images, + multi_apply, unmap) +from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, + cat_boxes) +from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean +from mmengine.config import ConfigDict +from mmengine.model import bias_init_with_prob, constant_init, normal_init +from mmengine.structures import InstanceData +from torch import Tensor + +from mmrotate.registry import MODELS, TASK_UTILS +from mmrotate.structures import RotatedBoxes + +INF = 1e8 + + +@MODELS.register_module() +class Point2RBoxYOLOFHead(YOLOFHead): + """Detection Head of `Point2RBox `_ + + Args: + num_classes (int): The number of object classes (w/o background) + in_channels (list[int]): The number of input channels per scale. + cls_num_convs (int): The number of convolutions of cls branch. + Defaults to 2. + reg_num_convs (int): The number of convolutions of reg branch. + Defaults to 4. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization + layer. Defaults to ``dict(type='BN', requires_grad=True)``. 
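+        use_bbox_hdr (bool): Whether to regress box sizes with the
+            multi-offset (HDR) branch in ``forward_single``. Defaults
+            to False.
+        use_objectness (bool): Whether to use the implicit objectness
+            branch. Defaults to True.
+        agnostic_cls (list): Indices of rotation-agnostic classes whose
+            angle target is fixed to 0. Defaults to [1, 9, 11].
+        square_cls (list): Indices of square classes whose angle is fixed
+            to 0 and aspect ratio to 1. Defaults to [0].
+        synthetic_pos_weight (float): Classification weight assigned to
+            positive samples matched to synthetic objects. Defaults to 0.1.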
+ """ + + def __init__(self, + num_classes: int, + in_channels: List[int], + num_cls_convs: int = 2, + num_reg_convs: int = 4, + norm_cfg: ConfigType = dict(type='BN', requires_grad=True), + use_bbox_hdr: bool = False, + use_transform_ss: bool = True, + use_objectness: bool = True, + full_supervised: bool = False, + agnostic_cls: list = [1, 9, 11], + square_cls: list = [0], + synthetic_pos_weight: float = 0.1, + loss_point: ConfigType = dict( + type='mmdet.L1Loss', loss_weight=0.1), + angle_coder: ConfigType = dict( + type='PSCCoder', + angle_version='le90', + dual_freq=False, + thr_mod=0), + loss_angle: ConfigType = dict( + type='mmdet.L1Loss', loss_weight=0.1), + loss_ratio: ConfigType = dict( + type='mmdet.L1Loss', loss_weight=1.0), + loss_symmetry_ss: ConfigType = dict( + type='mmdet.SmoothL1Loss', loss_weight=1.0, beta=0.1), + loss_scale_ss: ConfigType = dict( + type='mmdet.GIoULoss', loss_weight=0.05), + **kwargs) -> None: + self.num_cls_convs = num_cls_convs + self.num_reg_convs = num_reg_convs + self.norm_cfg = norm_cfg + self.use_bbox_hdr = use_bbox_hdr + self.use_transform_ss = use_transform_ss + self.use_objectness = use_objectness + self.full_supervised = full_supervised + self.agnostic_cls = agnostic_cls + self.square_cls = square_cls + self.synthetic_pos_weight = synthetic_pos_weight + self.angle_coder = TASK_UTILS.build(angle_coder) + super().__init__( + num_classes=num_classes, in_channels=in_channels, **kwargs) + self.loss_point = MODELS.build(loss_point) + self.loss_angle = MODELS.build(loss_angle) + self.loss_ratio = MODELS.build(loss_ratio) + self.loss_symmetry_ss = MODELS.build(loss_symmetry_ss) + self.loss_scale_ss = MODELS.build(loss_scale_ss) + + def _init_layers(self) -> None: + cls_subnet = [] + bbox_subnet = [] + ang_subnet = [] + for i in range(self.num_cls_convs): + cls_subnet.append( + ConvModule( + self.in_channels, + self.in_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg)) + for i in range(self.num_reg_convs): + bbox_subnet.append( + ConvModule( + self.in_channels, + self.in_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg)) + for i in range(self.num_reg_convs): + ang_subnet.append( + ConvModule( + self.in_channels, + self.in_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg)) + self.cls_subnet = nn.Sequential(*cls_subnet) + self.bbox_subnet = nn.Sequential(*bbox_subnet) + self.ang_subnet = nn.Sequential(*ang_subnet) + self.cls_score = nn.Conv2d( + self.in_channels, + self.num_base_priors * self.num_classes, + kernel_size=3, + stride=1, + padding=1) + self.cls_score_f = nn.Conv2d( + self.in_channels, + self.num_base_priors * self.num_classes, + kernel_size=3, + stride=1, + padding=1) + self.bbox_cent_pred = nn.Conv2d( + self.in_channels, + self.num_base_priors * 2, # CenterXY + kernel_size=3, + stride=1, + padding=1) + self.bbox_size_pred = nn.Conv2d( + self.in_channels, + self.num_base_priors * 2 * + (5 if self.use_bbox_hdr else 1), # SizeXY + kernel_size=3, + stride=1, + padding=1) + if self.use_objectness: + self.object_pred = nn.Conv2d( + self.in_channels, + self.num_base_priors, + kernel_size=3, + stride=1, + padding=1) + self.angle_pred = nn.Conv2d( + self.in_channels, + self.num_base_priors * self.angle_coder.encode_size, + kernel_size=3, + stride=1, + padding=1) + self.ratio_pred = nn.Conv2d( + self.in_channels, + self.num_base_priors, + kernel_size=3, + stride=1, + padding=1) + + def init_weights(self) -> None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, 
std=0.01) + if is_norm(m): + constant_init(m, 1) + + # Use prior in model initialization to improve stability + bias_cls = bias_init_with_prob(0.01) + torch.nn.init.constant_(self.cls_score.bias, bias_cls) + torch.nn.init.constant_(self.cls_score_f.bias, bias_cls) + + def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + + Returns: + tuple: + normalized_cls_score (Tensor): Normalized Cls scores for a \ + single scale level, the channels number is \ + num_base_priors * num_classes. + bbox_reg (Tensor): Box energies / deltas for a single scale \ + level, the channels number is num_base_priors * 4. + """ + cls_feat = self.cls_subnet(x) + cls_score = self.cls_score(cls_feat) + cls_score_f = self.cls_score_f(cls_feat) + N, _, H, W = cls_score.shape + cls_score = cls_score.view(N, -1, self.num_classes, H, W) + + reg_feat = self.bbox_subnet(x) + ang_feat = self.ang_subnet(x) + + if self.use_bbox_hdr: + r = self.bbox_size_pred(reg_feat).sigmoid() * 4 - 2 + r = r.view(r.shape[0], 5, -1, *r.shape[2:]) + w = torch.softmax(-r.abs(), 1) + o = r.new_tensor((-2, -1, 0, 1, 2))[None, :, None, None, None] + bbox_size_reg = (w * (r + o)).sum(1) * 2 + else: + bbox_size_reg = self.bbox_size_pred(reg_feat) + + bbox_cent_reg = self.bbox_cent_pred(reg_feat) + angle_reg = self.angle_pred(ang_feat) + ratio_reg = self.ratio_pred(ang_feat).sigmoid() + + bbox_reg = (bbox_cent_reg, bbox_size_reg, ratio_reg, angle_reg) + bbox_reg = [ + x.view(N, self.num_base_priors, -1, H, W) for x in bbox_reg + ] + bbox_reg = torch.cat(bbox_reg, 2).view(N, -1, H, W) + + # implicit objectness + if self.use_objectness: + objectness = self.object_pred(reg_feat) + objectness = objectness.view(N, -1, 1, H, W) + normalized_cls_score = cls_score + objectness - torch.log( + 1. + torch.clamp(cls_score.exp(), max=INF) + + torch.clamp(objectness.exp(), max=INF)) + normalized_cls_score = normalized_cls_score.view(N, -1, H, W) + else: + normalized_cls_score = cls_score.view(N, -1, H, W) + + return normalized_cls_score, cls_score_f, bbox_reg + + def obb2xyxy(self, obb): + w = obb[:, 2::5] + h = obb[:, 3::5] + a = obb[:, 4::5] + cosa = torch.cos(a).abs() + sina = torch.sin(a).abs() + hbbox_w = cosa * w + sina * h + hbbox_h = sina * w + cosa * h + dx = obb[..., 0] + dy = obb[..., 1] + dw = hbbox_w.reshape(-1) + dh = hbbox_h.reshape(-1) + x1 = dx - dw / 2 + y1 = dy - dh / 2 + x2 = dx + dw / 2 + y2 = dy + dh / 2 + return torch.stack((x1, y1, x2, y2), -1) + + def loss_by_feat( + self, + cls_scores: List[Tensor], + cls_scores_f: List[Tensor], + bbox_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. 
+ + Returns: + dict: A dictionary of loss components. + """ + assert len(cls_scores) == 1 + assert self.prior_generator.num_levels == 1 + + device = cls_scores[0].device + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + # The output level is always 1 + anchor_list = [anchors[0] for anchors in anchor_list] + valid_flag_list = [valid_flags[0] for valid_flags in valid_flag_list] + + # cls_scores_list = levels_to_images(cls_scores) + cls_scores_f_list = levels_to_images(cls_scores_f) + bbox_preds_list = levels_to_images(bbox_preds) + + cls_reg_targets = self.get_targets( + cls_scores_f_list, + bbox_preds_list, + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + if cls_reg_targets is None: + return None + (batch_labels, batch_label_weights, avg_factor, bbox_weights, + point_weights, pos_pred_boxes, target_boxes, target_labels, + target_bids) = cls_reg_targets + + flatten_labels = batch_labels.reshape(-1) + batch_label_weights = batch_label_weights.reshape(-1) + cls_score = cls_scores[0].permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + cls_score_f = cls_scores_f[0].permute(0, 2, 3, 1).reshape( + -1, self.cls_out_channels) + + point_f_mask = batch_label_weights == self.synthetic_pos_weight + label_weight = batch_label_weights.clone() + label_weight[point_f_mask] = 0 + avg_factor_point = max(avg_factor - point_f_mask.sum().item(), 1) + avg_factor_point = reduce_mean( + torch.tensor(avg_factor_point, dtype=torch.float, + device=device)).item() + + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + + # classification loss + if self.full_supervised: + loss_cls = self.loss_cls( + cls_score, + flatten_labels, + batch_label_weights, + avg_factor=avg_factor) + loss_cls_f = 0 * loss_cls + else: + loss_cls = self.loss_cls( + cls_score, + flatten_labels, + label_weight, + avg_factor=avg_factor_point) + loss_cls_f = self.loss_cls( + cls_score_f, + flatten_labels, + batch_label_weights, + avg_factor=avg_factor) + + # regression loss + if pos_pred_boxes.shape[0] == 0: + # no pos sample + loss_bbox = pos_pred_boxes.sum() * 0 + loss_angle = pos_pred_boxes.sum() * 0 + loss_ratio = pos_pred_boxes.sum() * 0 + loss_point = pos_pred_boxes.sum() * 0 + loss_scale_ss = pos_pred_boxes.sum() * 0 + loss_symmetry_ss = pos_pred_boxes.sum() * 0 + else: + target_boxes = RotatedBoxes(target_boxes).regularize_boxes('le90') + if self.agnostic_cls: + agnostic_mask = torch.stack([ + target_labels == c for c in self.agnostic_cls + ]).sum(0).bool() + else: + agnostic_mask = target_labels < 0 + target_boxes[agnostic_mask, 4] = 0 + if self.square_cls: + square_mask = torch.stack([ + target_labels == c for c in self.square_cls + ]).sum(0).bool() + else: + square_mask = target_labels < 0 + target_boxes[square_mask, 4] = 0 + + pos_pred_xyxy = pos_pred_boxes[:, :4] + target_xyxy = self.obb2xyxy(target_boxes) + loss_bbox = self.loss_bbox( + pos_pred_xyxy, + target_xyxy, + bbox_weights.float(), + avg_factor=bbox_weights.sum()) + + angle_weights = bbox_weights.clone().float() + angle_weights[agnostic_mask] = 0 + angle_weights[square_mask] = 0 + pos_pred_angle = pos_pred_boxes[:, 5:] + target_angle = self.angle_coder.encode(target_boxes[:, 4:]) + loss_angle = self.loss_angle( + pos_pred_angle, + target_angle, + angle_weights[:, None], + avg_factor=angle_weights.sum()) + + pos_pred_ratio = 
pos_pred_boxes[:, 4] + target_ratio = target_boxes[:, 3] / (target_boxes[:, 2] + 1e-5) + loss_ratio = self.loss_ratio( + pos_pred_ratio, + target_ratio, + angle_weights.float(), + avg_factor=angle_weights.sum()) + + pos_pred_cxcywh = bbox_xyxy_to_cxcywh(pos_pred_boxes[:, :4]) + pos_pred_cen = pos_pred_cxcywh[:, 0:2] + target_cen = target_boxes[:, 0:2] + point_valid = (pos_pred_cen - target_cen).abs().sum(1) < 32 + point_weights *= point_valid + loss_point = self.loss_point( + pos_pred_cen / 16, + target_cen / 16, + point_weights.float()[:, None], + avg_factor=point_weights.sum()) + + if self.use_transform_ss and loss_bbox.item( + ) < 0.5 and loss_angle.item() < 0.2: + # Self-supervision + # Calculate SS only for point annotations + target_bids[~point_weights] = -1 + # print(f'{target_bids[point_weights] = }') + + # Aggregate the same bbox based on their identical bid + bid, idx = torch.unique(target_bids, return_inverse=True) + pair_bid_targets = torch.empty_like(bid).index_reduce_( + 0, idx, target_bids, 'mean', include_self=False) + + # Generate a mask to eliminate bboxes without correspondence + # (bcnt is supposed to be 3, for ori, rot, and flp) + _, bidx, bcnt = torch.unique( + pair_bid_targets.long(), + return_inverse=True, + return_counts=True) + bmsk = bcnt[bidx] == 2 + + # print(pair_bid_targets) + b_sca = (pair_bid_targets % 1 > 0.7).sum() > 0 + + # The reduce all sample points of each object + pair_box_target = torch.empty_like(bid).index_reduce_( + 0, idx, target_boxes[:, 2], 'mean', + include_self=False)[bmsk].view(-1, 2) + pair_box_preds = torch.empty( + *bid.shape, pos_pred_cxcywh.shape[-1], + device=bid.device).index_reduce_( + 0, idx, pos_pred_cxcywh, 'mean', + include_self=False)[bmsk].view( + -1, 2, pos_pred_cxcywh.shape[-1]) + + ori_box = pair_box_preds[:, 0] + trs_box = pair_box_preds[:, 1] + + if b_sca: + sca = (pair_box_target[:, 1] / + pair_box_target[:, 0]).mean() + ori_box *= sca + + # Must limit the center and size range in ss + ss_weight_cen = (ori_box[:, :2] - + trs_box[:, :2]).abs().sum(1) < 32 + ss_weight_wh0 = (ori_box[:, 2:] + + trs_box[:, 2:]).sum(1) > 12 * 4 + ss_weight_wh1 = (ori_box[:, 2:] + + trs_box[:, 2:]).sum(1) < 512 * 4 + ss_weight = ss_weight_cen * ss_weight_wh0 * ss_weight_wh1 + if len(ori_box): + loss_scale_ss = self.loss_scale_ss( + bbox_cxcywh_to_xyxy(ori_box), + bbox_cxcywh_to_xyxy(trs_box), ss_weight) + else: + loss_scale_ss = pos_pred_cxcywh.sum() * 0 + loss_symmetry_ss = pos_pred_angle.sum() * 0 + else: + b_flp = (pair_bid_targets % 1 > 0.5).sum() > 0 + + # The reduce all sample points of each object + pair_angle_targets = torch.empty_like(bid).index_reduce_( + 0, idx, target_boxes[:, 4], 'mean', + include_self=False)[bmsk].view(-1, 2) + pair_angle_preds = torch.empty( + *bid.shape, + pos_pred_angle.shape[-1], + device=bid.device).index_reduce_( + 0, idx, pos_pred_angle, 'mean', + include_self=False)[bmsk].view( + -1, 2, pos_pred_angle.shape[-1]) + + pair_angle_preds = self.angle_coder.decode( + pair_angle_preds, keepdim=False) + + # Eliminate invalid pairs + img_shape = batch_img_metas[0]['img_shape'] + if b_flp: + flp_box = ori_box[:, :2] + flp_box[:, 1] = img_shape[0] - flp_box[:, 1] + ss_weight = (flp_box[:, :2] - + trs_box[:, :2]).abs().sum(1) < 32 + d_ang = pair_angle_preds[:, 0] + pair_angle_preds[:, 1] + else: + a = pair_angle_targets[:, 0] - pair_angle_targets[:, 1] + cosa = torch.cos(a) + sina = torch.sin(a) + m = torch.stack((cosa, sina, -sina, cosa), + -1).view(-1, 2, 2) + rot_box = torch.bmm( + m, (ori_box[:, :2] - 
img_shape[0] / + 2)[..., None])[:, :, 0] + img_shape[0] / 2 + ss_weight = (rot_box[:, :2] - + trs_box[:, :2]).abs().sum(1) < 32 + d_ang = (pair_angle_preds[:, 0] - + pair_angle_preds[:, 1]) - ( + pair_angle_targets[:, 0] - + pair_angle_targets[:, 1]) + + # Eliminate agnostic objects + if self.agnostic_cls: + pair_labels = torch.empty( + bid.shape, + dtype=target_labels.dtype, + device=bid.device).index_reduce_( + 0, + idx, + target_labels, + 'mean', + include_self=False)[bmsk].view(-1, 2)[:, 0] + pair_agnostic_mask = torch.stack([ + pair_labels == c for c in self.agnostic_cls + ]).sum(0).bool() + ss_weight[pair_agnostic_mask] = 0 + + d_ang = (d_ang + torch.pi / 2) % torch.pi - torch.pi / 2 + + loss_scale_ss = pos_pred_cxcywh.sum() * 0 + if len(d_ang): + loss_symmetry_ss = self.loss_symmetry_ss( + d_ang, torch.zeros_like(d_ang), ss_weight) + else: + loss_symmetry_ss = pos_pred_angle.sum() * 0 + else: + loss_scale_ss = pos_pred_cxcywh.sum() * 0 + loss_symmetry_ss = pos_pred_angle.sum() * 0 + + return dict( + loss_cls=loss_cls, + loss_cls_f=loss_cls_f, + loss_bbox=loss_bbox, + loss_angle=loss_angle, + loss_ratio=loss_ratio, + loss_point=loss_point, + loss_scale_ss=loss_scale_ss, + loss_symmetry_ss=loss_symmetry_ss) + + def get_targets(self, + cls_scores_list: List[Tensor], + bbox_preds_list: List[Tensor], + anchor_list: List[Tensor], + valid_flag_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs: bool = True): + """Compute regression and classification targets for anchors in + multiple images. + + Args: + cls_scores_list (list[Tensor]): Classification scores of + each image. each is a 4D-tensor, the shape is + (h * w, num_anchors * num_classes). + bbox_preds_list (list[Tensor]): Bbox preds of each image. + each is a 4D-tensor, the shape is (h * w, num_anchors * 4). + anchor_list (list[Tensor]): Anchors of each image. Each element of + is a tensor of shape (h * w * num_anchors, 4). + valid_flag_list (list[Tensor]): Valid flags of each image. Each + element of is a tensor of shape (h * w * num_anchors, ) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: Usually returns a tuple containing learning targets. + + - batch_labels (Tensor): Label of all images. Each element \ + of is a tensor of shape (batch, h * w * num_anchors) + - batch_label_weights (Tensor): Label weights of all images \ + of is a tensor of shape (batch, h * w * num_anchors) + - num_total_pos (int): Number of positive samples in all \ + images. + - num_total_neg (int): Number of negative samples in all \ + images. + additional_returns: This function enables user-defined returns from + `self._get_targets_single`. These returns are currently refined + to properties at each feature map (i.e. having HxW dimension). 
+ The results will be concatenated after the end + """ + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + results = multi_apply( + self._get_targets_single, + cls_scores_list, + bbox_preds_list, + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + (all_labels, all_label_weights, pos_inds, neg_inds, + sampling_results_list) = results[:5] + # Get `avg_factor` of all images, which calculate in `SamplingResult`. + # When using sampling method, avg_factor is usually the sum of + # positive and negative priors. When using `PseudoSampler`, + # `avg_factor` is usually equal to the number of positive priors. + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + rest_results = list(results[5:]) # user-added return values + + batch_labels = torch.stack(all_labels, 0) + batch_label_weights = torch.stack(all_label_weights, 0) + + res = (batch_labels, batch_label_weights, avg_factor) + for i, rests in enumerate(rest_results): # user-added return values + rest_results[i] = torch.cat(rests, 0) + + return res + tuple(rest_results) + + def _get_targets_single(self, + cls_scores: Tensor, + bbox_preds: Tensor, + flat_anchors: Tensor, + valid_flags: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression and classification targets for anchors in a + single image. + + Args: + bbox_preds (Tensor): Bbox prediction of the image, which + shape is (h * w ,4) + flat_anchors (Tensor): Anchors of the image, which shape is + (h * w * num_anchors ,4) + valid_flags (Tensor): Valid flags of the image, which shape is + (h * w * num_anchors,). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: + labels (Tensor): Labels of image, which shape is + (h * w * num_anchors, ). + label_weights (Tensor): Label weights of image, which shape is + (h * w * num_anchors, ). + pos_inds (Tensor): Pos index of image. + neg_inds (Tensor): Neg index of image. + sampling_result (obj:`SamplingResult`): Sampling result. + pos_bbox_weights (Tensor): The Weight of using to calculate + the bbox branch loss, which shape is (num, ). + pos_predicted_boxes (Tensor): boxes predicted value of + using to calculate the bbox branch loss, which shape is + (num, 4). + pos_target_boxes (Tensor): boxes target value of + using to calculate the bbox branch loss, which shape is + (num, 4). + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. 
Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + dim = self.bbox_coder.encode_size + self.angle_coder.encode_size + 1 + bbox_preds = bbox_preds.reshape(-1, dim) + bbox_preds = bbox_preds[inside_flags, :] + + ### + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + cls_scores = cls_scores[inside_flags, :] + + # decoded bbox + decoder_bbox_preds = self.bbox_coder.decode(anchors, bbox_preds[:, :4]) + decoder_bbox_preds = torch.cat((decoder_bbox_preds, bbox_preds[:, 4:]), + -1) + pred_instances = InstanceData( + priors=anchors, + decoder_priors=decoder_bbox_preds, + cls_scores=cls_scores) + assign_result = self.assigner.assign(pred_instances, gt_instances, + gt_instances_ignore) + + pos_point_index = assign_result.get_extra_property('pos_point_index') + pos_bbox_weights = assign_result.get_extra_property('pos_bbox_mask') + pos_point_weights = assign_result.get_extra_property('pos_point_mask') + pos_predicted_boxes = assign_result.get_extra_property( + 'pos_predicted_boxes') + pos_target_boxes = assign_result.get_extra_property('target_boxes') + pos_target_labels = assign_result.get_extra_property('target_labels') + pos_target_bids = assign_result.get_extra_property('target_bids') + + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + num_valid_anchors = anchors.shape[0] + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = self.synthetic_pos_weight + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + if pos_point_index is not None: + label_weights[pos_point_index.reshape(-1)] = 1 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + labels = unmap( + labels, num_total_anchors, inside_flags, + fill=self.num_classes) # fill bg label + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + + return (labels, label_weights, pos_inds, neg_inds, sampling_result, + pos_bbox_weights, pos_point_weights, pos_predicted_boxes, + pos_target_boxes, pos_target_labels, pos_target_bids) + + def predict_by_feat(self, + cls_scores: List[Tensor], + cls_scores_f: List[Tensor], + bbox_preds: List[Tensor], + score_factors: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + + result_list = super().predict_by_feat(cls_scores, bbox_preds, + score_factors, batch_img_metas, + cfg, rescale, with_nms) + + return result_list + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). 
+ bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + if score_factor_list[0] is None: + # e.g. Retina, FreeAnchor, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, etc. + with_score_factors = True + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_valid_priors = [] + mlvl_scores = [] + mlvl_labels = [] + if with_score_factors: + mlvl_score_factors = [] + else: + mlvl_score_factors = None + for level_idx, (cls_score, bbox_pred, score_factor, priors) in \ + enumerate(zip(cls_score_list, bbox_pred_list, + score_factor_list, mlvl_priors)): + + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + dim = ( + self.bbox_coder.encode_size + self.angle_coder.encode_size + 1) + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim) + if with_score_factors: + score_factor = score_factor.permute(1, 2, + 0).reshape(-1).sigmoid() + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + scores = cls_score.softmax(-1)[:, :-1] + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. 
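+            # Per level: drop predictions scoring below `score_thr`, keep at
+            # most `nms_pre` of the rest, and gather the corresponding bbox
+            # predictions and priors.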
+ score_thr = cfg.get('score_thr', 0) + + results = filter_scores_and_topk( + scores, score_thr, nms_pre, + dict(bbox_pred=bbox_pred, priors=priors)) + scores, labels, keep_idxs, filtered_results = results + + bbox_pred = filtered_results['bbox_pred'] + priors = filtered_results['priors'] + + if with_score_factors: + score_factor = score_factor[keep_idxs] + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + if with_score_factors: + mlvl_score_factors.append(score_factor) + + bbox_pred = torch.cat(mlvl_bbox_preds) + priors = cat_boxes(mlvl_valid_priors) + bboxes = self.bbox_coder.decode( + priors, bbox_pred[:, :4], max_shape=img_shape) + bboxes = bbox_xyxy_to_cxcywh(bboxes) + ratios = bbox_pred[:, 4:5].clamp(0.05, 1) + angles = self.angle_coder.decode(bbox_pred[:, 5:], keepdim=True) + + labels = torch.cat(mlvl_labels) + if self.agnostic_cls: + agnostic_mask = torch.stack( + [labels == c for c in self.agnostic_cls]).sum(0).bool() + else: + agnostic_mask = labels < 0 + if self.square_cls: + square_mask = torch.stack([labels == c for c in self.square_cls + ]).sum(0).bool() + else: + square_mask = labels < 0 + angles[agnostic_mask] = 0 + ratios[agnostic_mask] = 1 + ratios[square_mask] = 1 + + cosa = torch.cos(angles).abs() + sina = torch.sin(angles).abs() + m = torch.stack( + (ratios, -torch.ones_like(ratios), cosa, sina, sina, cosa), + -1).view(-1, 3, 2) + b = torch.cat((torch.zeros_like(bboxes[:, 2:3]), bboxes[:, 2:4]), + 1)[..., None] + wh = torch.linalg.lstsq(m, b).solution[:, :, 0] + wh[square_mask] *= 1.4 + + # For DIOR + if self.cls_out_channels == 20: + wh[labels == 0] *= 0.8 + wh[labels == 17] *= 1.2 + wh[labels == 13] *= 0.7 + wh[labels == 18] *= 0.7 + wh[labels == 19] *= 0.7 + + bboxes = torch.cat((bboxes[:, 0:2], wh, angles), 1) + + results = InstanceData() + results.bboxes = RotatedBoxes(bboxes) + results.scores = torch.cat(mlvl_scores) + results.labels = labels + if with_score_factors: + results.score_factors = torch.cat(mlvl_score_factors) + + results = self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) + + return results diff --git a/mmrotate/models/detectors/__init__.py b/mmrotate/models/detectors/__init__.py index cbd025b5c..44a5017d7 100644 --- a/mmrotate/models/detectors/__init__.py +++ b/mmrotate/models/detectors/__init__.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .h2rbox import H2RBoxDetector from .h2rbox_v2 import H2RBoxV2Detector +from .point2rbox_yolof import Point2RBoxYOLOF from .refine_single_stage import RefineSingleStageDetector -__all__ = ['RefineSingleStageDetector', 'H2RBoxDetector', 'H2RBoxV2Detector'] +__all__ = [ + 'RefineSingleStageDetector', 'H2RBoxDetector', 'H2RBoxV2Detector', + 'Point2RBoxYOLOF' +] diff --git a/mmrotate/models/detectors/point2rbox_yolof.py b/mmrotate/models/detectors/point2rbox_yolof.py new file mode 100755 index 000000000..5f094b0b6 --- /dev/null +++ b/mmrotate/models/detectors/point2rbox_yolof.py @@ -0,0 +1,345 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
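+# Point2RBox learns rotated boxes from single-point labels: synthetic objects
+# with known boxes are pasted around the annotated points (``add_synthesis``),
+# and the original view is paired with a rotated/flipped/scaled view
+# (``rotate_crop`` and the ``loss`` method below) so the bbox head can apply
+# its transform self-supervision losses.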
+import copy +import math +from typing import Tuple, Union + +import torch +from mmdet.models.detectors.single_stage import SingleStageDetector +from mmdet.models.utils import unpack_gt_instances +from mmdet.structures import DetDataSample, SampleList +from mmdet.structures.bbox import get_box_tensor +from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig +from mmengine.structures import InstanceData +from torch import Tensor +from torch.nn.functional import grid_sample +from torchvision import transforms + +from mmrotate.models.task_modules.synthesis_generators import \ + point2rbox_generator +from mmrotate.registry import MODELS +from mmrotate.structures.bbox import RotatedBoxes + + +@MODELS.register_module() +class Point2RBoxYOLOF(SingleStageDetector): + r"""Implementation of `Point2RBox + `_ + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone module. + neck (:obj:`ConfigDict` or dict): The neck module. + bbox_head (:obj:`ConfigDict` or dict): The bbox head module. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of YOLOF. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of YOLOF. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): + Model preprocessing config for processing the input data. + it usually includes ``to_rgb``, ``pad_size_divisor``, + ``pad_value``, ``mean`` and ``std``. Defaults to None. + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + crop_size: Tuple[int, int] = (1024, 1024), + padding: str = 'reflection', + rot_range: Tuple[float, float] = (0.25, 0.75), + sca_range: Tuple[float, float] = (0.5, 1.5), + sca_fact: float = 1.0, + prob_rot: float = 0.475, + prob_flp: float = 0.025, + basic_pattern: str = 'data/dota', + dense_cls: list = [], + use_synthesis: bool = True, + use_setrc: bool = True, + use_setsk: bool = True, + debug: bool = False, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + self.crop_size = crop_size + self.padding = padding + self.rot_range = rot_range + self.sca_range = sca_range + self.sca_fact = sca_fact + self.prob_rot = prob_rot + self.prob_flp = prob_flp + self.basic_pattern = basic_pattern + self.dense_cls = dense_cls + self.use_synthesis = use_synthesis + self.debug = debug + self.basic_pattern = point2rbox_generator.load_basic_pattern( + self.basic_pattern, use_setrc, use_setsk) + + def rotate_crop( + self, + batch_inputs: Tensor, + rot: float = 0., + size: Tuple[int, int] = (768, 768), + batch_gt_instances: InstanceList = None, + padding: str = 'reflection') -> Tuple[Tensor, InstanceList]: + """ + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + rot (float): Angle of view rotation. Defaults to 0. + size (tuple[int]): Crop size from image center. + Defaults to (768, 768). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + padding (str): Padding method of image black edge. + Defaults to 'reflection'. 
+ + Returns: + Processed batch_inputs (Tensor) and batch_gt_instances + (list[:obj:`InstanceData`]) + """ + device = batch_inputs.device + n, c, h, w = batch_inputs.shape + size_h, size_w = size + crop_h = (h - size_h) // 2 + crop_w = (w - size_w) // 2 + if rot != 0: + cosa, sina = math.cos(rot), math.sin(rot) + tf = batch_inputs.new_tensor([[cosa, -sina], [sina, cosa]], + dtype=torch.float) + x_range = torch.linspace(-1, 1, w, device=device) + y_range = torch.linspace(-1, 1, h, device=device) + y, x = torch.meshgrid(y_range, x_range) + grid = torch.stack([x, y], -1).expand([n, -1, -1, -1]) + grid = grid.reshape(-1, 2).matmul(tf).view(n, h, w, 2) + # rotate + batch_inputs = grid_sample( + batch_inputs, grid, 'bilinear', padding, align_corners=True) + if batch_gt_instances is not None: + for i, gt_instances in enumerate(batch_gt_instances): + gt_bboxes = get_box_tensor(gt_instances.bboxes) + xy, wh, a = gt_bboxes[..., :2], gt_bboxes[ + ..., 2:4], gt_bboxes[..., [4]] + ctr = tf.new_tensor([[w / 2, h / 2]]) + xy = (xy - ctr).matmul(tf.T) + ctr + a = a + rot + rot_gt_bboxes = torch.cat([xy, wh, a], dim=-1) + batch_gt_instances[i].bboxes = RotatedBoxes(rot_gt_bboxes) + batch_inputs = batch_inputs[..., crop_h:crop_h + size_h, + crop_w:crop_w + size_w] + if batch_gt_instances is None: + return batch_inputs + else: + for i, gt_instances in enumerate(batch_gt_instances): + gt_bboxes = get_box_tensor(gt_instances.bboxes) + xy, wh, a = gt_bboxes[..., :2], gt_bboxes[..., + 2:4], gt_bboxes[..., + [4]] + xy = xy - xy.new_tensor([[crop_w, crop_h]]) + crop_gt_bboxes = torch.cat([xy, wh, a], dim=-1) + batch_gt_instances[i].bboxes = RotatedBoxes(crop_gt_bboxes) + + return batch_inputs, batch_gt_instances + + def add_synthesis(self, batch_inputs, batch_gt_instances): + + def synthesis_single(img, bboxes, labels): + labels = labels[:, None] + bb = torch.cat((bboxes, torch.ones_like(labels), labels), -1) + img, bb = point2rbox_generator.generate_sythesis( + img, bb, self.sca_fact, *self.basic_pattern, self.dense_cls, + self.crop_size[0]) + instance_data = InstanceData() + instance_data.labels = bb[:, 6].long() + instance_data.bboxes = bb[:, :5] + return img, instance_data + + p = ((synthesis_single)(img, gt.bboxes.cpu(), gt.labels.cpu()) + for (img, gt) in zip(batch_inputs.cpu(), batch_gt_instances)) + + img, instance_data = zip(*p) + batch_inputs = torch.stack(img, 0).to(batch_inputs) + instance_data = list(instance_data) + for i, gt in enumerate(instance_data): + gt.labels = gt.labels.to(batch_gt_instances[i].labels) + gt.bboxes = gt.bboxes.to(batch_gt_instances[i].bboxes) + batch_gt_instances[i] = InstanceData.cat( + [batch_gt_instances[i], gt]) + + return batch_inputs, batch_gt_instances + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. 
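+
+        Note:
+            Each gt instance is tagged with a ``bids`` value: the integer
+            part is a unique per-batch instance id and the fractional part
+            marks the view it belongs to (0.2 original, 0.4 rotated,
+            0.6 flipped, 0.8 scaled), so that the same object can be
+            matched across the two views for transform self-supervision.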
+ """ + batch_gt_instances, _, _ = unpack_gt_instances(batch_data_samples) + + # Generate synthetic objects + + if self.use_synthesis: + batch_inputs, batch_gt_instances = self.add_synthesis( + batch_inputs, batch_gt_instances) + + if self.bbox_head.use_transform_ss: + # Crop original images and gts + batch_inputs, batch_gt_instances = self.rotate_crop( + batch_inputs, 0, self.crop_size, batch_gt_instances, + self.padding) + offset = 1 + for gt_instances in batch_gt_instances: + gt_instances.bids = torch.arange( + 0, + len(gt_instances.bboxes), + 1, + device=gt_instances.bboxes.device) + offset + 0.2 + offset += len(gt_instances.bboxes) + + # Concat original/rotated/flipped images and gts + p = torch.rand(1) + if p < self.prob_rot: # rot + # Generate rotated images and gts + rot = math.pi * ( + torch.rand(1, device=batch_inputs.device) * + (self.rot_range[1] - self.rot_range[0]) + + self.rot_range[0]) + batch_gt_rot = copy.deepcopy(batch_gt_instances) + batch_inputs_rot, batch_gt_rot = self.rotate_crop( + batch_inputs, rot, self.crop_size, batch_gt_rot, + self.padding) + offset = 1 + for gt_instances in batch_gt_rot: + gt_instances.bids = torch.arange( + 0, + len(gt_instances.bboxes), + 1, + device=gt_instances.bboxes.device) + offset + 0.4 + offset += len(gt_instances.bboxes) + batch_inputs_all = torch.cat((batch_inputs, batch_inputs_rot)) + batch_gt_instances_all = batch_gt_instances + batch_gt_rot + elif p < self.prob_rot + self.prob_flp: # flp + # Generate flipped images and gts + batch_inputs_flp = transforms.functional.vflip(batch_inputs) + batch_gt_flp = copy.deepcopy(batch_gt_instances) + offset = 1 + for gt_instances in batch_gt_flp: + gt_instances.bboxes.flip_(batch_inputs.shape[2:4], + 'vertical') + gt_instances.bids = torch.arange( + 0, + len(gt_instances.bboxes), + 1, + device=gt_instances.bboxes.device) + offset + 0.6 + offset += len(gt_instances.bboxes) + batch_inputs_all = torch.cat((batch_inputs, batch_inputs_flp)) + batch_gt_instances_all = batch_gt_instances + batch_gt_flp + else: # sca + # Generate scaled images and gts + sca = torch.rand( + 1, device=batch_inputs.device + ) * (self.sca_range[1] - self.sca_range[0]) + self.sca_range[0] + size = (self.crop_size[0] / sca).long() + batch_inputs_sca = transforms.functional.resized_crop( + batch_inputs, + 0, + 0, + size, + size, + self.crop_size, + antialias=False) + batch_gt_sca = copy.deepcopy(batch_gt_instances) + offset = 1 + for gt_instances in batch_gt_sca: + gt_instances.bboxes.rescale_((sca, sca)) + gt_instances.bids = torch.arange( + 0, + len(gt_instances.bboxes), + 1, + device=gt_instances.bboxes.device) + offset + 0.8 + offset += len(gt_instances.bboxes) + batch_inputs_all = torch.cat((batch_inputs, batch_inputs_sca)) + batch_gt_instances_all = batch_gt_instances + batch_gt_sca + + batch_gt_instances_filtered = [] + for gt_instances in batch_gt_instances_all: + H = self.crop_size[0] + D = 16 + ignore_mask = torch.logical_or( + gt_instances.bboxes.tensor[:, :2].min(1)[0] < D, + gt_instances.bboxes.tensor[:, :2].max(1)[0] > H - D) + gt_instances_filtered = InstanceData() + gt_instances_filtered.bboxes = RotatedBoxes( + gt_instances.bboxes.tensor[~ignore_mask]) + gt_instances_filtered.labels = gt_instances.labels[ + ~ignore_mask] + gt_instances_filtered.bids = gt_instances.bids[~ignore_mask] + batch_gt_instances_filtered.append(gt_instances_filtered) + + batch_data_samples_all = [] + for gt_instances in batch_gt_instances_filtered: + data_sample = DetDataSample( + metainfo=batch_data_samples[0].metainfo) + 
data_sample.gt_instances = gt_instances + batch_data_samples_all.append(data_sample) + + else: + offset = 1 + for gt_instances in batch_gt_instances: + gt_instances.bids = torch.arange( + 0, + len(gt_instances.bboxes), + 1, + device=gt_instances.bboxes.device) + offset + 0.2 + offset += len(gt_instances.bboxes) + + batch_inputs_all = batch_inputs + batch_data_samples_all = [] + for gt_instances in batch_gt_instances: + data_sample = DetDataSample( + metainfo=batch_data_samples[0].metainfo) + gt_instances.bboxes = RotatedBoxes(gt_instances.bboxes) + data_sample.gt_instances = gt_instances + batch_data_samples_all.append(data_sample) + + # Plot + if self.debug: + import cv2 + import numpy as np + idx = np.random.randint(100) + B = batch_inputs.shape[0] + batch_inputs_plot = batch_inputs_all[::B] + batch_data_samples_plot = batch_data_samples_all[::B] + for i in range(len(batch_inputs_plot)): + img = np.ascontiguousarray( + (batch_inputs_plot[i].permute(1, 2, 0).cpu().numpy() + + 127)) + bb = batch_data_samples_plot[i].gt_instances.bboxes + for b in bb.cpu().numpy(): + point2rbox_generator.plot_one_rotated_box(img, b) + cv2.imwrite(f'{idx}-{i}.png', img) + + feat = self.extract_feat(batch_inputs_all) + losses = self.bbox_head.loss(feat, batch_data_samples_all) + + return losses diff --git a/mmrotate/models/task_modules/assigners/__init__.py b/mmrotate/models/task_modules/assigners/__init__.py index 0e1d912b3..7c7a37d17 100644 --- a/mmrotate/models/task_modules/assigners/__init__.py +++ b/mmrotate/models/task_modules/assigners/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .convex_assigner import ConvexAssigner from .max_convex_iou_assigner import MaxConvexIoUAssigner +from .point2rbox_assigner import Point2RBoxAssigner from .rotate_iou2d_calculator import (FakeRBboxOverlaps2D, QBbox2HBboxOverlaps2D, RBbox2HBboxOverlaps2D, RBboxOverlaps2D) @@ -10,5 +11,5 @@ __all__ = [ 'ConvexAssigner', 'MaxConvexIoUAssigner', 'SASAssigner', 'RotatedATSSAssigner', 'RBboxOverlaps2D', 'FakeRBboxOverlaps2D', - 'RBbox2HBboxOverlaps2D', 'QBbox2HBboxOverlaps2D' + 'RBbox2HBboxOverlaps2D', 'QBbox2HBboxOverlaps2D', 'Point2RBoxAssigner' ] diff --git a/mmrotate/models/task_modules/assigners/point2rbox_assigner.py b/mmrotate/models/task_modules/assigners/point2rbox_assigner.py new file mode 100755 index 000000000..2f2072634 --- /dev/null +++ b/mmrotate/models/task_modules/assigners/point2rbox_assigner.py @@ -0,0 +1,230 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmdet.models.task_modules.assigners.assign_result import AssignResult +from mmdet.models.task_modules.assigners.base_assigner import BaseAssigner +from mmdet.structures.bbox import bbox_xyxy_to_cxcywh +from mmdet.utils import ConfigType +from mmengine.structures import InstanceData + +from mmrotate.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class Point2RBoxAssigner(BaseAssigner): + """Point2RBoxAssigner between the priors and gt boxes, which can achieve + balance in positive priors, and gt_bboxes_ignore was not considered for + now. + + Args: + pos_ignore_thr (float): the threshold to ignore positive priors + neg_ignore_thr (float): the threshold to ignore negative priors + match_times(int): Number of positive priors for each gt box. + Defaults to 4. + iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou + calculator. 
+            Defaults to ``dict(type='mmdet.BboxOverlaps2D')``.
+    """
+
+    def __init__(
+        self,
+        pos_ignore_thr: float,
+        neg_ignore_thr: float,
+        match_times: int = 4,
+        iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D')
+    ) -> None:
+        self.match_times = match_times
+        self.pos_ignore_thr = pos_ignore_thr
+        self.neg_ignore_thr = neg_ignore_thr
+        self.iou_calculator = TASK_UTILS.build(iou_calculator)
+
+    def obb2xyxy(self, obb):
+        w = obb[:, 2::5]
+        h = obb[:, 3::5]
+        a = obb[:, 4::5].detach()
+        cosa = torch.cos(a).abs()
+        sina = torch.sin(a).abs()
+        hbbox_w = cosa * w + sina * h
+        hbbox_h = sina * w + cosa * h
+        dx = obb[..., 0]
+        dy = obb[..., 1]
+        dw = hbbox_w.reshape(-1)
+        dh = hbbox_h.reshape(-1)
+        x1 = dx - dw / 2
+        y1 = dy - dh / 2
+        x2 = dx + dw / 2
+        y2 = dy + dh / 2
+        return torch.stack((x1, y1, x2, y2), -1)
+
+    def assign(
+        self,
+        pred_instances: InstanceData,
+        gt_instances: InstanceData,
+        gt_instances_ignore: Optional[InstanceData] = None
+    ) -> AssignResult:
+        """Assign gt to priors.
+
+        The assignment is done in the following steps:
+
+        1. Assign every prior to background (index 0) by default.
+        2. Compute the L1 cost between gt boxes and predicted boxes. Note
+           that both the priors and the predicted boxes are used.
+        3. Compute the indexes to ignore using the gt boxes and predicted
+           boxes.
+        4. Compute the indexes of positive samples to ignore using the
+           priors and predicted boxes.
+
+        Args:
+            pred_instances (:obj:`InstanceData`): Instances of model
+                predictions. It includes ``priors``, and the priors can
+                be priors, points, or bboxes predicted by the model,
+                shape (n, 4).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
+                to be ignored during training. It includes ``bboxes``
+                attribute data that is ignored during training and testing.
+                Defaults to None.
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+        """
+
+        gt_bboxes = gt_instances.bboxes.tensor
+        gt_labels = gt_instances.labels
+        gt_bids = gt_instances.bids
+        priors = pred_instances.priors
+        bbox_pred = pred_instances.decoder_priors
+        cls_scores = pred_instances.cls_scores
+
+        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
+
+        # 1. Assign 0 (background) by default
+        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
+                                              0,
+                                              dtype=torch.long)
+        assigned_labels = bbox_pred.new_full((num_bboxes, ),
+                                             -1,
+                                             dtype=torch.long)
+        if num_gts == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            if num_gts == 0:
+                # No ground truth, assign all to background
+                assigned_gt_inds[:] = 0
+            assign_result = AssignResult(
+                num_gts, assigned_gt_inds, None, labels=assigned_labels)
+            assign_result.set_extra_property(
+                'pos_bbox_mask', bbox_pred.new_empty(0, dtype=torch.bool))
+            assign_result.set_extra_property(
+                'pos_point_mask', bbox_pred.new_empty(0, dtype=torch.bool))
+            assign_result.set_extra_property(
+                'pos_predicted_boxes',
+                bbox_pred.new_empty((0, bbox_pred.shape[-1])))
+            assign_result.set_extra_property(
+                'target_boxes', gt_bboxes.new_empty((0, gt_bboxes.shape[-1])))
+            assign_result.set_extra_property('target_labels',
+                                             gt_labels.new_empty((0, )))
+            assign_result.set_extra_property('target_bids',
+                                             gt_bids.new_empty((0, )))
+            return assign_result
+
+        # 2.
Compute the L1 cost between boxes + # Note that we use priors and predict boxes both + bbox_pred_xyxy = bbox_pred[:, :4] + point_mask = gt_bboxes[:, 2] < 1 + gt_bboxes_xyxy = self.obb2xyxy(gt_bboxes) + + cost_center = torch.cdist( + bbox_xyxy_to_cxcywh(bbox_pred_xyxy)[:, :2], gt_bboxes[:, :2], p=1) + cost_cls_scores = 1 - cls_scores[:, gt_labels].sigmoid() + cost_cls_scores[cost_center > 32] = 1e5 + + cost_bbox = cost_cls_scores.clone() + cost_bbox_priors = torch.cdist( + bbox_xyxy_to_cxcywh(priors), + bbox_xyxy_to_cxcywh(gt_bboxes_xyxy), + p=1) * cost_cls_scores + cost_bbox[:, point_mask] = 1e9 + cost_bbox_priors[:, point_mask] = 1e9 + + # 32 is the L1-dist between two adjacent diagonal anchors (stride=16) + cost_cls_scores[:, ~point_mask] = 1e9 + + # We found that topk function has different results in cpu and + # cuda mode. In order to ensure consistency with the source code, + # we also use cpu mode. + # TODO: Check whether the performance of cpu and cuda are the same. + C = cost_bbox.cpu() + C1 = cost_bbox_priors.cpu() + C2 = cost_cls_scores.cpu() + + # self.match_times x n + index = torch.topk(C, k=self.match_times, dim=0, largest=False)[1] + index1 = torch.topk(C1, k=self.match_times, dim=0, largest=False)[1] + index2 = torch.topk(C2, k=self.match_times, dim=0, largest=False)[1] + strong_idx = index2[:, point_mask.cpu()] + + # (self.match_times*2) x n + indexes = torch.cat((index, index1, index2), + dim=1).reshape(-1).to(bbox_pred.device) + + pred_overlaps = self.iou_calculator(bbox_pred_xyxy, gt_bboxes_xyxy) + anchor_overlaps = self.iou_calculator(priors, gt_bboxes_xyxy) + # anchor_overlaps[:, point_mask] = 1 + pred_max_overlaps, _ = pred_overlaps.max(dim=1) + anchor_max_overlaps, _ = anchor_overlaps.max(dim=0) + + # 3. Compute the ignore indexes use gt_bboxes and predict boxes + ignore_idx = pred_max_overlaps > self.neg_ignore_thr + assigned_gt_inds[ignore_idx] = -1 + + # 4. 
Compute the ignore indexes of positive sample use priors + # and predict boxes + pos_gt_index = torch.arange( + 0, C1.size(1), + device=bbox_pred.device).repeat(self.match_times * 3) + pos_ious = anchor_overlaps[indexes, pos_gt_index] + pos_ignore_idx = pos_ious < self.pos_ignore_thr + + # Bbox pos weight, the same as YOLOF + pos_bbox_idx = ~pos_ignore_idx + # Point pos weight, False for index and index1, True for index2 + pos_point_idx = pos_ignore_idx.new_zeros(self.match_times, 3, + C1.size(1)) + pos_point_idx[:, 2, point_mask] = True + pos_point_idx = pos_point_idx.reshape(-1) + pos_point_idx = torch.logical_and(pos_point_idx, pos_ious > 0) + # When the pos is ignored by both bbox and point, not assign gt label + pos_ignore_idx = torch.logical_and(pos_ignore_idx, ~pos_point_idx) + + pos_gt_index_with_ignore = pos_gt_index + 1 + pos_gt_index_with_ignore[pos_ignore_idx] = -1 + assigned_gt_inds[indexes] = pos_gt_index_with_ignore + + if gt_labels is not None: + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + else: + assigned_labels = None + + assign_result = AssignResult( + num_gts, + assigned_gt_inds, + anchor_max_overlaps, + labels=assigned_labels) + assign_result.set_extra_property('pos_point_index', strong_idx) + assign_result.set_extra_property('pos_bbox_mask', pos_bbox_idx) + assign_result.set_extra_property('pos_point_mask', pos_point_idx) + assign_result.set_extra_property('pos_predicted_boxes', + bbox_pred[indexes]) + assign_result.set_extra_property('target_boxes', + gt_bboxes[pos_gt_index]) + assign_result.set_extra_property('target_labels', + gt_labels[pos_gt_index]) + assign_result.set_extra_property('target_bids', gt_bids[pos_gt_index]) + return assign_result diff --git a/mmrotate/models/task_modules/synthesis_generators/point2rbox_generator.py b/mmrotate/models/task_modules/synthesis_generators/point2rbox_generator.py new file mode 100755 index 000000000..e94ec9f9c --- /dev/null +++ b/mmrotate/models/task_modules/synthesis_generators/point2rbox_generator.py @@ -0,0 +1,282 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
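+# Synthetic visual pattern generator used by Point2RBox.
+# ``load_basic_pattern`` reads per-class prior sizes and optional sketch
+# images from the basic pattern folder, and ``generate_sythesis`` pastes
+# randomly scaled, rotated, and recolored patterns (using a color palette
+# sampled around each labelled point) onto the training image, returning
+# the image together with the synthetic rotated boxes and their labels.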
+import os + +import cv2 +import numpy as np +import torch +import torchvision +from mmcv.ops import nms_rotated + + +def get_pattern_fill(w, h): + w, h = int(w), int(h) + p = np.ones((h, w)) + cv2.line(p, (0, 0), (0, h - 1), 0.01, 1) + cv2.line(p, (0, 0), (w - 1, 0), 0.01, 1) + cv2.line(p, (w - 1, h - 1), (0, h - 1), 0.01, 1) + cv2.line(p, (w - 1, h - 1), (w - 1, 0), 0.01, 1) + return p + + +def get_pattern_line(w, h): + w, h = int(w), int(h) + p = np.ones((h, w)) + interval_range = [3, 6] + xn = np.random.randint(*interval_range) + yn = np.random.randint(*interval_range) + for i in range(xn): + x = np.intp(np.round((w - 1) * i / (xn - 1))) + cv2.line(p, (x, 0), (x, h), 0.5, 1) + for i in range(yn): + y = np.intp(np.round((h - 1) * i / (yn - 1))) + cv2.line(p, (0, y), (w, y), 0.5, 1) + return torch.from_numpy(p) + + +def get_pattern_rose(w, h): + w, h = int(w), int(h) + p = np.ones((h, w), dtype=float) + t = np.float32(range(100)) + omega_range = [2, 4] + xn = np.random.randint(*omega_range) + x = np.sin(t / 99 * 2 * np.pi) * np.cos( + t / 100 * 2 * np.pi * xn) * w / 2 + w / 2 + y = np.cos(t / 99 * 2 * np.pi) * np.cos( + t / 100 * 2 * np.pi * 2) * h / 2 + h / 2 + xy = np.stack((x, y), -1) + cv2.polylines(p, np.int_([xy]), True, 0.5, 1) + return torch.from_numpy(p) + + +def get_pattern_li(w, h): + w, h = int(w), int(h) + p = np.ones((h, w), dtype=float) + t = np.float32(range(100)) + s = np.random.rand() * 8 + s2 = np.random.rand() * 0.5 + 0.1 + r = (np.abs(np.cos(t / 99 * 4 * np.pi))**s) * (1 - s2) + s2 + x = r * np.sin(t / 99 * 2 * np.pi) * w / 2 + w / 2 + y = r * np.cos(t / 99 * 2 * np.pi) * h / 2 + h / 2 + xy = np.stack((x, y), -1) + cv2.fillPoly(p, np.int_([xy]), 1) + cv2.polylines(p, np.int_([xy]), True, 0.5, 1) + return torch.from_numpy(p) + + +def obb2xyxy(obb): + w = obb[:, 2] + h = obb[:, 3] + a = obb[:, 4] + cosa = torch.cos(a).abs() + sina = torch.sin(a).abs() + dw = cosa * w + sina * h + dh = sina * w + cosa * h + dx = obb[..., 0] + dy = obb[..., 1] + x1 = dx - dw / 2 + y1 = dy - dh / 2 + x2 = dx + dw / 2 + y2 = dy + dh / 2 + return torch.stack((x1, y1, x2, y2), -1) + + +def plot_one_rotated_box(img, + obb, + color=[0.0, 0.0, 128], + label=None, + line_thickness=None): + width, height, theta = obb[2], obb[3], obb[4] / np.pi * 180 + if theta < 0: + width, height, theta = height, width, theta + 90 + rect = [(obb[0], obb[1]), (width, height), theta] + poly = np.intp(np.round( + cv2.boxPoints(rect))) # [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] + cv2.drawContours( + image=img, contours=[poly], contourIdx=-1, color=color, thickness=2) + + +def get_pattern_gaussian(w, h, device): + w, h = int(w), int(h) + y, x = torch.meshgrid( + torch.arange(h, device=device), + torch.arange(w, device=device), + indexing='ij') + y = (y - h / 2) / (h / 2) + x = (x - w / 2) / (w / 2) + ox, oy = torch.randn(2, device=device).clip(-3, 3) * 0.15 + sx, sy = torch.rand(2, device=device) + 0.3 + z = torch.exp(-((x - ox) * sx)**2 - ((y - oy) * sy)**2) * 0.9 + 0.1 + return z + + +def generate_sythesis(img, bb_occupied, sca_fact, pattern, prior_size, + dense_cls, imgsize): + device = img.device + cen_range = [50, imgsize - 50] + + base_scale = (torch.randn(1, device=device) * 0.4).clamp(-1, 1) * sca_fact + base_scale = torch.exp(base_scale) + + bb_occupied = bb_occupied.clone() + bb_occupied[:, 2] = prior_size[bb_occupied[:, 6].long(), 0] * 0.7 + bb_occupied[:, 3] = prior_size[bb_occupied[:, 6].long(), 0] * 0.7 + bb_occupied[:, 4] = 0 + + bb = [] + palette = [ + torch.zeros(0, 6, device=device) for _ in 
range(len(prior_size)) + ] + adjboost = 2 + for b in bb_occupied: + x, y = torch.rand( + 2, device=device) * (cen_range[1] - cen_range[0]) + cen_range[0] + dw = prior_size[b[6].long(), 2] + w = (torch.randn(1, device=device) * 0.4).clamp(-1, 1) * dw + w = base_scale * torch.exp(w) + dr = prior_size[b[6].long(), 3] + r = (torch.randn(1, device=device) * 0.4).clamp(-1, 1) * dr + h = w * torch.exp(r) + w *= prior_size[b[6].long(), 0] + h *= prior_size[b[6].long(), 1] + a = torch.rand(1, device=device) * torch.pi + x = x.clip(0.71 * w, imgsize - 1 - 0.71 * w) + y = y.clip(0.71 * h, imgsize - 1 - 0.71 * h) + bb.append([x, y, w, h, a, (w * h) / imgsize / imgsize + 0.1, b[6]]) + + bx = torch.clip(b[0:2], 16, imgsize - 1 - 16).long() + nbr0 = img[:, bx[1] - 2:bx[1] + 3, bx[0] - 2:bx[0] + 3].reshape(3, -1) + nbr1 = img[:, bx[1] - 16:bx[1] + 17, + bx[0] - 16:bx[0] + 17].reshape(3, -1) + c0 = nbr0.mean(1) + c1 = nbr1[:, (nbr1.mean(0) - c0.mean()).abs().max(0)[1]] + c = torch.cat((c0, c1), 0) + palette[b[6].long()] = torch.cat((palette[b[6].long()], c[None]), 0) + + if np.random.random() < 0.2 and adjboost > 0: + adjboost -= 1 + if b[6].long() in dense_cls: + itv, dev = torch.rand( + 1, device=device) * 4 + 2, torch.rand( + 1, device=device) * 8 - 4 + ofx = (h + itv) * torch.sin(-a) + dev * torch.cos(a) + ofy = (h + itv) * torch.cos(a) + dev * torch.sin(a) + for k in range(1, 6): + bb.append([ + x + k * ofx, y + k * ofy, w, h, a, + (w * h) / imgsize / imgsize + 0.1 - 0.001 * k, b[6] + ]) + else: + itv, dev = torch.rand( + 1, device=device) * 40 + 10, torch.rand( + 1, device=device) * 0 + ofx = (h + itv) * torch.sin(-a) + dev * torch.cos(a) + ofy = (h + itv) * torch.cos(a) + dev * torch.sin(a) + for k in range(1, 4): + bb.append([ + x + k * ofx, y + k * ofy, w, h, a, + (w * h) / imgsize / imgsize + 0.1 - 0.001 * k, b[6] + ]) + + bb = torch.tensor(bb) + bb = torch.cat((bb_occupied, bb), 0) + _, keep = nms_rotated(bb[:, 0:5], bb[:, 5], 0.05) + bb = bb[keep] + bb = bb[bb[:, 5] < 1] + + xyxy = obb2xyxy(bb) + mask = torch.logical_and( + xyxy.min(-1)[0] >= 0, + xyxy.max(-1)[0] <= imgsize - 1) + bb, xyxy = bb[mask], xyxy[mask] + + for i in range(len(bb)): + cx, cy, w, h, t, s, c = bb[i, :7] + ox, oy = torch.floor(xyxy[i, 0:2]).long() + ex, ey = torch.ceil(xyxy[i, 2:4]).long() + sx, sy = ex - ox, ey - oy + theta = torch.tensor( + [[1 / w * torch.cos(t), 1 / w * torch.sin(t), 0], + [1 / h * torch.sin(-t), 1 / h * torch.cos(t), 0]], + dtype=torch.float, + device=device) + theta[:, :2] @= torch.tensor([[sx, 0], [0, sy]], + dtype=torch.float, + device=device) + grid = torch.nn.functional.affine_grid( + theta[None], (1, 1, sy, sx), align_corners=False) + p = pattern[c.long()] + p = p[np.random.randint(0, len(p))][None].clone() + trans = torchvision.transforms.Compose([ + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.RandomVerticalFlip(), + ]) + if np.random.random() < 0.2: + p *= get_pattern_line(p.shape[2], p.shape[1]) + if np.random.random() < 0.2: + p *= get_pattern_rose(p.shape[2], p.shape[1]) + if np.random.random() < 0.2: + p *= get_pattern_li(p.shape[2], p.shape[1]) + p = trans(p) + p = torch.nn.functional.grid_sample( + p[None], grid, align_corners=False, mode='nearest')[0] + if np.random.random() < 0.9: + a = get_pattern_gaussian(sx, sy, device) * (p != 0) + else: + a = (p != 0).float() + pal = palette[c.long()] + color = pal[torch.randint(0, len(pal), (1, ))][0] + p = p * color[:3, None, None] + (1 - p) * color[3:, None, None] + + img[:, oy:oy + sy, + ox:ox + sx] = (1 - a) * 
img[:, oy:oy + sy, ox:ox + sx] + a * p + + return img, bb + + +def load_basic_pattern(path, set_rc=True, set_sk=True): + with open(os.path.join(path, 'properties.txt')) as f: + prior = eval(f.read()) + prior_size = torch.tensor(list(prior.values())) + pattern = [] + for i in range(len(prior_size)): + p = [] + if set_rc: + img = get_pattern_fill(*prior_size[i, (0, 1)]) + p.append(torch.from_numpy(img).float()) + if set_sk: + if os.path.exists(os.path.join(path, f'{i}.png')): + img = cv2.imread(os.path.join(path, f'{i}.png'), 0) + img = img / 255 + else: + img = get_pattern_fill(*prior_size[i, (0, 1)]) + p.append(torch.from_numpy(img).float()) + pattern.append(p) + return pattern, prior_size + + +if __name__ == '__main__': + pattern, prior_size = load_basic_pattern( + '../../../../data/basic_patterns/dota') + + img = cv2.imread('../../../../data/basic_patterns/test/P2805.png') - 127.0 + img = torch.from_numpy(img).permute(2, 0, 1).float().contiguous() + bb = torch.tensor(((318, 84, 0, 0, 0.5, 1, 4), ) * 10) + + import time + c = time.time() + # img, bb1 = generate_sythesis(img, bb, 1) + # print(time.time() - c) + # bb1[:, 5] = 1 + c = time.time() + img, bb2 = generate_sythesis(img, bb, 0.5, pattern, prior_size, (4, 5, 6), + 1024) + print(time.time() - c) + + img = img.permute(1, 2, 0).contiguous().cpu().numpy() + # for b in bb1.cpu().numpy(): + # plot_one_rotated_box(img, b, color=[0, 100, 0]) + # for b in bb2.cpu().numpy(): + # plot_one_rotated_box(img, b, color=[0, 0, 100]) + cv2.imshow('', (img + 127) / 256) + # cv2.imwrite('s1.png', (img + 127)) + cv2.waitKey()