Skip to content

Commit

Permalink
better anchor generator
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Mar 5, 2021
1 parent d225a7c commit 3efd446
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 71 deletions.
54 changes: 28 additions & 26 deletions mmdet2trt/converters/anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,35 @@ def convert_AnchorGeneratorDynamic(ctx):
stride = get_arg(ctx, 'stride', pos=2, default=base_size)
if hasattr(ag.generator, 'base_anchors'):
base_anchors = ag.generator.base_anchors[index]
base_anchors = base_anchors.view(-1).cpu().numpy()
plugin = create_gridanchordynamic_plugin("ag_" + str(id(module)),
base_size=base_size,
stride=stride,
base_anchors=base_anchors)
else:
scales = ag.scales.detach().cpu().numpy().astype(np.float32)
ratios = ag.ratios.detach().cpu().numpy().astype(np.float32)
scale_major = ag.scale_major
ctr = ag.ctr
if ctr is None:
# center_x = -1
# center_y = -1
center_x = 0
center_y = 0
else:
center_x, center_y = ag.ctr
# base_anchors = base_anchors.view(-1).cpu().numpy()
base_anchors_trt = trt_(ctx.network, base_anchors.float())

plugin = create_gridanchordynamic_plugin("ag_" + str(id(module)),
base_size=base_size,
stride=stride,
scales=scales,
ratios=ratios,
scale_major=scale_major,
center_x=center_x,
center_y=center_y)

custom_layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin)
stride=stride)
else:
print("no base_anchors in {}".format(ag.generator))
# scales = ag.scales.detach().cpu().numpy().astype(np.float32)
# ratios = ag.ratios.detach().cpu().numpy().astype(np.float32)
# scale_major = ag.scale_major
# ctr = ag.ctr
# if ctr is None:
# # center_x = -1
# # center_y = -1
# center_x = 0
# center_y = 0
# else:
# center_x, center_y = ag.ctr

# plugin = create_gridanchordynamic_plugin("ag_" + str(id(module)),
# base_size=base_size,
# stride=stride,
# scales=scales,
# ratios=ratios,
# scale_major=scale_major,
# center_x=center_x,
# center_y=center_y)

custom_layer = ctx.network.add_plugin_v2(
inputs=[input_trt, base_anchors_trt], plugin=plugin)

output._trt = custom_layer.get_output(0)
46 changes: 1 addition & 45 deletions mmdet2trt/converters/plugins/create_gridanchordynamic_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,15 @@


def create_gridanchordynamic_plugin(layer_name,
base_size,
stride,
scales=np.array([1.]),
ratios=np.array([1.]),
scale_major=True,
center_x=-1,
center_y=-1,
base_anchors=None):
stride):

creator = trt.get_plugin_registry().get_plugin_creator(
'GridAnchorDynamicPluginDynamic', '1', '')

pfc = trt.PluginFieldCollection()

pf_base_size = trt.PluginField("base_size",
np.array([base_size], dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_base_size)

pf_stride = trt.PluginField("stride", np.array([stride], dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_stride)

pf_scales = trt.PluginField("scales",
np.array(scales).astype(np.float32),
trt.PluginFieldType.FLOAT32)
pfc.append(pf_scales)

pf_ratios = trt.PluginField("ratios",
np.array(ratios).astype(np.float32),
trt.PluginFieldType.FLOAT32)
pfc.append(pf_ratios)

pf_scale_major = trt.PluginField(
"scale_major", np.array([int(scale_major)], dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_scale_major)

pf_center_x = trt.PluginField("center_x",
np.array([center_x], dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_center_x)

pf_center_y = trt.PluginField("center_y",
np.array([center_y], dtype=np.int32),
trt.PluginFieldType.INT32)
pfc.append(pf_center_y)

if base_anchors is not None:
pf_base_anchors = trt.PluginField(
"base_anchors",
np.array(base_anchors).astype(np.float32),
trt.PluginFieldType.FLOAT32)
pfc.append(pf_base_anchors)

return creator.create_plugin(layer_name, pfc)

0 comments on commit 3efd446

Please sign in to comment.