Skip to content

Commit

Permalink
add DETR, compatible with mmdetection 2.10
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Mar 13, 2021
1 parent 3efd446 commit aac0505
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ read https://github.com/NVIDIA-AI-IOT/torch2trt#how-does-it-work for detail.
- [x] Mask R-CNN (experimental)
- [x] Cascade Mask R-CNN (experimental)
- [x] Cascade RPN
- [x] DETR


Tested on:
- torch=1.6.0
- tensorrt=7.1.3.4
- mmdetection=2.5.0
- mmdetection=2.10.0
- cuda=10.2
- cudnn=8.0.2.39

Expand Down
2 changes: 1 addition & 1 deletion mmdet2trt/core/bbox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .coder import *
from .iou_calculators import *
from .transforms import batched_distance2bbox
from .transforms import batched_distance2bbox, batched_bbox_cxcywh_to_xyxy
10 changes: 10 additions & 0 deletions mmdet2trt/core/bbox/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,13 @@ def bbox2roi(proposals):
proposals = proposals.view(-1, 4)
rois = torch.cat([rois_pad, proposals], dim=1)
return rois


def batched_bbox_cxcywh_to_xyxy(bbox):
    """Convert boxes from center-size (cx, cy, w, h) to corner (x1, y1, x2, y2).

    Generalized from the original fixed ``[batch, num, 4]`` indexing to
    ellipsis indexing, so any number of leading batch dimensions works
    (backward compatible with 3-D input).

    Args:
        bbox (torch.Tensor): boxes of shape ``(..., 4)`` in
            (cx, cy, w, h) format.

    Returns:
        torch.Tensor: boxes of the same shape in (x1, y1, x2, y2) format.
    """
    cx = bbox[..., 0]
    cy = bbox[..., 1]
    half_w = 0.5 * bbox[..., 2]
    half_h = 0.5 * bbox[..., 3]
    return torch.stack(
        [cx - half_w, cy - half_h, cx + half_w, cy + half_h], dim=-1)
1 change: 1 addition & 0 deletions mmdet2trt/models/backbones/base_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
@register_wraper("mmdet.models.backbones.Darknet")
@register_wraper("mmdet.models.backbones.DetectoRS_ResNet")
@register_wraper("mmdet.models.backbones.HourglassNet")
@register_wraper("mmdet.models.backbones.resnext.ResNeXt")
class BaseBackboneWraper(nn.Module):
def __init__(self, module):
super(BaseBackboneWraper, self).__init__()
Expand Down
3 changes: 2 additions & 1 deletion mmdet2trt/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
from .gfl_head import GFLHeadWraper
from .centripetal_head import CentripetalHeadWraper
from .vfnet_head import VFNetHeadWraper
from .cascade_rpn_head import StageCascadeRPNHeadWraper, CascadeRPNHeadWraper
from .cascade_rpn_head import StageCascadeRPNHeadWraper, CascadeRPNHeadWraper
from .transformer_head import TransformerHeadWraper
6 changes: 6 additions & 0 deletions mmdet2trt/models/dense_heads/ga_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ class GARPNHeadWraper(GuidedAnchorHeadWraper):
def __init__(self, module):
    """Wrap a GARPNHead and normalize its test_cfg for TensorRT NMS.

    mmdet >= 2.x stores NMS settings in a nested ``nms`` dict and uses
    ``max_per_img``; flatten them back onto the legacy flat field names
    (``nms_thr``, ``nms_post``, ``max_num``) that this wrapper reads.
    """
    super(GARPNHeadWraper, self).__init__(module)

    cfg = module.test_cfg
    if 'nms' in cfg:
        cfg.nms_thr = cfg.nms['iou_threshold']
    if 'max_per_img' in cfg:
        cfg.nms_post = cfg.max_per_img
        cfg.max_num = cfg.max_per_img
    self.test_cfg = cfg
    # score threshold 0.0: keep every proposal, let NMS + top-k filter
    self.rpn_nms = BatchedNMS(0.0, cfg.nms_thr, -1)

def forward(self, feat, x):
Expand Down
4 changes: 4 additions & 0 deletions mmdet2trt/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def __init__(self, module):
self.bbox_coder = build_wraper(self.module.bbox_coder)

self.test_cfg = module.test_cfg
if 'nms' in self.test_cfg:
self.test_cfg.nms_thr = self.test_cfg.nms['iou_threshold']
if 'max_per_img' in self.test_cfg:
self.test_cfg.nms_post = self.test_cfg.max_per_img
self.rpn_nms = BatchedNMS(0.0, self.test_cfg.nms_thr, -1)

def forward(self, feat, x):
Expand Down
65 changes: 65 additions & 0 deletions mmdet2trt/models/dense_heads/transformer_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
from mmdet2trt.models.builder import register_wraper
import torch
from torch import nn
from torch.nn import functional as F

from mmdet2trt.core.bbox.transforms import batched_bbox_cxcywh_to_xyxy


@register_wraper("mmdet.models.dense_heads.TransformerHead")
class TransformerHeadWraper(nn.Module):
    """TensorRT-convertible wrapper around mmdet's DETR TransformerHead."""

    def __init__(self, module):
        super(TransformerHeadWraper, self).__init__()
        self.module = module
        self.test_cfg = module.test_cfg

    def module_forward(self, feats, x):
        """Run the wrapped head on each feature map.

        Returns:
            tuple[list, list]: per-level classification scores and
            sigmoid-activated box predictions.
        """
        head = self.module
        batch_size = x.shape[0]
        # all-zero mask: no image padding is assumed at inference time
        masks = feats[0].new_zeros((batch_size, x.shape[2], x.shape[3]))

        cls_scores = []
        bbox_preds = []
        for feat in feats:
            feat = head.input_proj(feat)
            mask = F.interpolate(
                masks.unsqueeze(1),
                size=feat.shape[-2:]).to(torch.bool).squeeze(1)
            # pos_embed: [bs, embed_dim, h, w]
            pos_embed = head.positional_encoding(mask)
            # outs_dec: [nb_dec, bs, num_query, embed_dim]
            outs_dec, _ = head.transformer(feat, mask,
                                           head.query_embedding.weight,
                                           pos_embed)
            cls_scores.append(head.fc_cls(outs_dec))
            bbox_preds.append(
                head.fc_reg(head.activate(head.reg_ffn(outs_dec))).sigmoid())

        return cls_scores, bbox_preds

    def forward(self, feats, x):
        """Decode DETR predictions into (num_dets, boxes, scores, labels)."""
        img_h, img_w = x.shape[2], x.shape[3]

        cls_scores, bbox_preds = self.module_forward(feats, x)

        # take the last feature level; index 0 selects along the decoder
        # dimension of [nb_dec, bs, num_query, ...]
        cls_score = cls_scores[-1][0]
        bbox_pred = bbox_preds[-1][0]

        # drop the trailing background class before picking the best class
        scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)

        # cxcywh in [0, 1] -> clamped xyxy scaled to input pixels
        bbox_pred = batched_bbox_cxcywh_to_xyxy(bbox_pred).clamp(min=0, max=1)
        scale = [img_w, img_h, img_w, img_h]
        bbox_pred = torch.stack(
            [bbox_pred[:, :, i] * scale[i] for i in range(4)], dim=-1)

        # DETR keeps every query as a detection; num_dets == num_query
        num_dets = (scores[:, :1] * 0).int() + scores.shape[1]

        return num_dets, bbox_pred, scores, det_labels
1 change: 1 addition & 0 deletions mmdet2trt/models/detectors/single_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
@register_wraper("mmdet.models.RetinaNet")
@register_wraper("mmdet.models.SingleStageDetector")
@register_wraper("mmdet.models.VFNet")
@register_wraper("mmdet.models.DETR")
class SingleStageDetectorWraper(nn.Module):
def __init__(self, model, wrap_config={}):
super(SingleStageDetectorWraper, self).__init__()
Expand Down

0 comments on commit aac0505

Please sign in to comment.