Skip to content

Commit

Permalink
better dcn and dcnv2
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Mar 1, 2021
1 parent 1af26d6 commit d225a7c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 84 deletions.
43 changes: 13 additions & 30 deletions mmdet2trt/converters/DeformConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ def convert_DeformConv(ctx):

input_trt = trt_(ctx.network, input)
offset_trt = trt_(ctx.network, offset)

kernel_size = weight.shape[2]
if not isinstance(kernel_size, tuple):
kernel_size = (kernel_size, ) * 2
weight_trt = trt_(ctx.network, weight)

if not isinstance(stride, tuple):
stride = (stride, ) * 2
Expand All @@ -35,21 +32,15 @@ def convert_DeformConv(ctx):
if not isinstance(dilation, tuple):
dilation = (dilation, ) * 2

kernel = weight.detach().cpu().numpy()
out_channels = output.shape[1]

plugin = create_dcn_plugin("dcn_" + str(id(input)),
out_channels=out_channels,
kernel_size=kernel_size,
W=kernel,
padding=padding,
stride=stride,
padding=padding,
dilation=dilation,
deformable_group=deform_groups,
group=groups)

custom_layer = ctx.network.add_plugin_v2(inputs=[input_trt, offset_trt],
plugin=plugin)
custom_layer = ctx.network.add_plugin_v2(
inputs=[input_trt, offset_trt, weight_trt], plugin=plugin)

output._trt = custom_layer.get_output(0)

Expand All @@ -73,10 +64,9 @@ def convert_ModulatedDeformConv(ctx):
input_trt = trt_(ctx.network, input)
offset_trt = trt_(ctx.network, offset)
mask_trt = trt_(ctx.network, mask)

kernel_size = weight.shape[2]
if not isinstance(kernel_size, tuple):
kernel_size = (kernel_size, ) * 2
weight_trt = trt_(ctx.network, weight)
if bias is not None:
bias_trt = trt_(ctx.network, bias)

if not isinstance(stride, tuple):
stride = (stride, ) * 2
Expand All @@ -87,24 +77,17 @@ def convert_ModulatedDeformConv(ctx):
if not isinstance(dilation, tuple):
dilation = (dilation, ) * 2

kernel = weight.detach().cpu().numpy()
out_channels = output.shape[1]

if bias is not None:
bias = bias.detach().cpu().numpy()

plugin = create_dcnv2_plugin("dcn_" + str(id(input)),
out_channels=out_channels,
kernel_size=kernel_size,
W=kernel,
B=bias,
padding=padding,
stride=stride,
padding=padding,
dilation=dilation,
deformable_group=deform_groups,
group=groups)

custom_layer = ctx.network.add_plugin_v2(
inputs=[input_trt, offset_trt, mask_trt], plugin=plugin)
layer_inputs = [input_trt, offset_trt, mask_trt, weight_trt]
if bias is not None:
layer_inputs += [bias_trt]
custom_layer = ctx.network.add_plugin_v2(inputs=layer_inputs,
plugin=plugin)

output._trt = custom_layer.get_output(0)
54 changes: 0 additions & 54 deletions mmdet2trt/converters/plugins/create_dcn_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@


def create_dcn_plugin(layer_name,
out_channels,
kernel_size,
W,
type_id=trt.DataType.FLOAT,
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
Expand All @@ -23,9 +19,6 @@ def create_dcn_plugin(layer_name,

creator = trt.get_plugin_registry().get_plugin_creator(
'DeformableConvPluginDynamic', '1', '')
if not isinstance(kernel_size, Iterable):
kernel_size = [kernel_size, kernel_size]

if not isinstance(stride, Iterable):
stride = [stride, stride]

Expand All @@ -37,24 +30,6 @@ def create_dcn_plugin(layer_name,

pfc = trt.PluginFieldCollection()

pf_out_channels = trt.PluginField("out_dims",
np.array([out_channels], dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_out_channels)

pf_kernel_size = trt.PluginField("kernel_size",
np.array(kernel_size, dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_kernel_size)

pf_W = trt.PluginField("W", W, trt.PluginFieldType.FLOAT32)
pfc.append(pf_W)

pf_type_id = trt.PluginField("type_id", np.array([type_id],
dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_type_id)

pf_stride = trt.PluginField("stride", np.array(stride, dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_stride)
Expand All @@ -81,10 +56,6 @@ def create_dcn_plugin(layer_name,


def create_dcnv2_plugin(layer_name,
out_channels,
kernel_size,
W,
B=None,
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
Expand All @@ -94,9 +65,6 @@ def create_dcnv2_plugin(layer_name,
type_id = trt.DataType.FLOAT
creator = trt.get_plugin_registry().get_plugin_creator(
'ModulatedDeformableConvPluginDynamic', '1', '')
if not isinstance(kernel_size, Iterable):
kernel_size = [kernel_size, kernel_size]

if not isinstance(stride, Iterable):
stride = [stride, stride]

Expand All @@ -108,28 +76,6 @@ def create_dcnv2_plugin(layer_name,

pfc = trt.PluginFieldCollection()

pf_out_channels = trt.PluginField("out_dims",
np.array([out_channels], dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_out_channels)

pf_kernel_size = trt.PluginField("kernel_size",
np.array(kernel_size, dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_kernel_size)

pf_W = trt.PluginField("W", W, trt.PluginFieldType.FLOAT32)
pfc.append(pf_W)

if B is not None:
pf_B = trt.PluginField("B", B, trt.PluginFieldType.FLOAT32)
pfc.append(pf_B)

pf_type_id = trt.PluginField("type_id", np.array([type_id],
dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_type_id)

pf_stride = trt.PluginField("stride", np.array(stride, dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_stride)
Expand Down

0 comments on commit d225a7c

Please sign in to comment.