def batched_bbox_cxcywh_to_xyxy(bbox):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2).

    The conversion is applied along the last dimension, so the input may
    carry any number of leading dimensions (e.g. ``[batch, num_boxes, 4]``
    as used by the batched heads, or a plain ``[num_boxes, 4]``).

    Args:
        bbox (torch.Tensor): Boxes whose last dimension stores
            ``(center_x, center_y, width, height)``.

    Returns:
        torch.Tensor: Boxes with the same leading dimensions as ``bbox``
        and a last dimension of ``(x1, y1, x2, y2)``.
    """
    # Ellipsis indexing keeps the function rank-agnostic; for the 3-D
    # inputs used so far it is equivalent to ``bbox[:, :, k]``.
    cx = bbox[..., 0]
    cy = bbox[..., 1]
    half_w = 0.5 * bbox[..., 2]
    half_h = 0.5 * bbox[..., 3]

    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    return torch.stack(corners, dim=-1)
super(BaseBackboneWraper, self).__init__() diff --git a/mmdet2trt/models/dense_heads/__init__.py b/mmdet2trt/models/dense_heads/__init__.py index 7f83f5a..9fb6218 100644 --- a/mmdet2trt/models/dense_heads/__init__.py +++ b/mmdet2trt/models/dense_heads/__init__.py @@ -14,4 +14,5 @@ from .gfl_head import GFLHeadWraper from .centripetal_head import CentripetalHeadWraper from .vfnet_head import VFNetHeadWraper -from .cascade_rpn_head import StageCascadeRPNHeadWraper, CascadeRPNHeadWraper \ No newline at end of file +from .cascade_rpn_head import StageCascadeRPNHeadWraper, CascadeRPNHeadWraper +from .transformer_head import TransformerHeadWraper \ No newline at end of file diff --git a/mmdet2trt/models/dense_heads/ga_rpn_head.py b/mmdet2trt/models/dense_heads/ga_rpn_head.py index 58dd8ef..b7063b6 100644 --- a/mmdet2trt/models/dense_heads/ga_rpn_head.py +++ b/mmdet2trt/models/dense_heads/ga_rpn_head.py @@ -12,6 +12,12 @@ class GARPNHeadWraper(GuidedAnchorHeadWraper): def __init__(self, module): super(GARPNHeadWraper, self).__init__(module) + self.test_cfg = module.test_cfg + if 'nms' in self.test_cfg: + self.test_cfg.nms_thr = self.test_cfg.nms['iou_threshold'] + if 'max_per_img' in self.test_cfg: + self.test_cfg.nms_post = self.test_cfg.max_per_img + self.test_cfg.max_num = self.test_cfg.max_per_img self.rpn_nms = BatchedNMS(0.0, self.test_cfg.nms_thr, -1) def forward(self, feat, x): diff --git a/mmdet2trt/models/dense_heads/rpn_head.py b/mmdet2trt/models/dense_heads/rpn_head.py index df3c6c0..bbe2d0d 100644 --- a/mmdet2trt/models/dense_heads/rpn_head.py +++ b/mmdet2trt/models/dense_heads/rpn_head.py @@ -16,6 +16,10 @@ def __init__(self, module): self.bbox_coder = build_wraper(self.module.bbox_coder) self.test_cfg = module.test_cfg + if 'nms' in self.test_cfg: + self.test_cfg.nms_thr = self.test_cfg.nms['iou_threshold'] + if 'max_per_img' in self.test_cfg: + self.test_cfg.nms_post = self.test_cfg.max_per_img self.rpn_nms = BatchedNMS(0.0, self.test_cfg.nms_thr, -1) def 
@register_wraper("mmdet.models.dense_heads.TransformerHead")
class TransformerHeadWraper(nn.Module):
    """Conversion-friendly wrapper around mmdet's ``TransformerHead`` (DETR).

    The wrapper re-implements the head's forward pass and box decoding with
    plain tensor operations and returns a fixed-size detection tuple
    ``(num_dets, boxes, scores, labels)``.
    """

    def __init__(self, module):
        """Store the wrapped head and its test configuration.

        Args:
            module: The ``TransformerHead`` instance to wrap.
        """
        super(TransformerHeadWraper, self).__init__()
        self.module = module
        self.test_cfg = module.test_cfg

    def module_forward(self, feats, x):
        """Run the wrapped head on every feature level.

        Args:
            feats: Iterable of feature maps produced by the backbone/neck.
            x (torch.Tensor): Network input; its shape is unpacked as
                ``(batch, _, height, width)``.

        Returns:
            tuple[list, list]: Per-level classification scores and
            sigmoid-activated box predictions. Each entry keeps the leading
            decoder-layer dimension of ``outs_dec``.
        """
        module = self.module
        batch_size, _, input_img_h, input_img_w = x.shape
        # All-zero mask: every pixel of the input is treated as valid.
        masks = feats[0].new_zeros((batch_size, input_img_h, input_img_w))

        cls_scores = []
        bbox_preds = []
        for feat in feats:
            feat = module.input_proj(feat)
            # Resize the mask to the feature-map resolution.
            masks_interp = F.interpolate(
                masks.unsqueeze(1),
                size=feat.shape[-2:]).to(torch.bool).squeeze(1)
            pos_embed = module.positional_encoding(
                masks_interp)  # [bs, embed_dim, h, w]
            # outs_dec: [nb_dec, bs, num_query, embed_dim]
            outs_dec, _ = module.transformer(feat, masks_interp,
                                             module.query_embedding.weight,
                                             pos_embed)

            all_cls_scores = module.fc_cls(outs_dec)
            all_bbox_preds = module.fc_reg(
                module.activate(module.reg_ffn(outs_dec))).sigmoid()

            cls_scores.append(all_cls_scores)
            bbox_preds.append(all_bbox_preds)

        return cls_scores, bbox_preds

    def forward(self, feats, x):
        """Decode detections from the last level and last decoder layer.

        Args:
            feats: Iterable of feature maps produced by the backbone/neck.
            x (torch.Tensor): Network input; ``x.shape[2:]`` is used as the
                ``(height, width)`` the normalized boxes are scaled to.

        Returns:
            tuple: ``(num_dets, bbox_pred, scores, det_labels)`` where
            ``num_dets`` is an int tensor filled with the number of queries,
            ``bbox_pred`` holds ``(x1, y1, x2, y2)`` boxes in input-pixel
            units, and ``scores`` / ``det_labels`` are the best
            non-background class score and index per query.
        """
        img_shape = x.shape[2:]

        cls_scores, bbox_preds = self.module_forward(feats, x)

        # First index: last feature level. Second index: last decoder
        # layer (dim 0 of outs_dec is nb_dec). Index 0 would select the
        # output of the *first* decoder layer; -1 is also correct when the
        # decoder returns a single stacked layer.
        cls_score = cls_scores[-1][-1]
        bbox_pred = bbox_preds[-1][-1]

        # The last class channel is dropped before taking the maximum
        # (presumably the background / no-object class -- confirm against
        # the wrapped head's class layout).
        scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)

        # Predictions are sigmoid outputs in cxcywh form: convert to corner
        # form, clip to the unit square, then scale to the input size.
        bbox_pred = batched_bbox_cxcywh_to_xyxy(bbox_pred)
        bbox_pred = bbox_pred.clamp(min=0, max=1)
        bbox_pred0 = bbox_pred[:, :, 0] * img_shape[1]
        bbox_pred1 = bbox_pred[:, :, 1] * img_shape[0]
        bbox_pred2 = bbox_pred[:, :, 2] * img_shape[1]
        bbox_pred3 = bbox_pred[:, :, 3] * img_shape[0]
        bbox_pred = torch.stack(
            [bbox_pred0, bbox_pred1, bbox_pred2, bbox_pred3], dim=-1)

        # Built from ``scores`` so the count stays a tensor of shape
        # [batch, 1] tied to the graph rather than a Python constant.
        num_dets = (scores[:, :1] * 0).int() + scores.shape[1]

        return num_dets, bbox_pred, scores, det_labels