diff --git a/mmdet2trt/converters/anchor_generator.py b/mmdet2trt/converters/anchor_generator.py index 80f66d2..ce451d2 100644 --- a/mmdet2trt/converters/anchor_generator.py +++ b/mmdet2trt/converters/anchor_generator.py @@ -21,33 +21,35 @@ def convert_AnchorGeneratorDynamic(ctx): stride = get_arg(ctx, 'stride', pos=2, default=base_size) if hasattr(ag.generator, 'base_anchors'): base_anchors = ag.generator.base_anchors[index] - base_anchors = base_anchors.view(-1).cpu().numpy() - plugin = create_gridanchordynamic_plugin("ag_" + str(id(module)), - base_size=base_size, - stride=stride, - base_anchors=base_anchors) - else: - scales = ag.scales.detach().cpu().numpy().astype(np.float32) - ratios = ag.ratios.detach().cpu().numpy().astype(np.float32) - scale_major = ag.scale_major - ctr = ag.ctr - if ctr is None: - # center_x = -1 - # center_y = -1 - center_x = 0 - center_y = 0 - else: - center_x, center_y = ag.ctr + # base_anchors = base_anchors.view(-1).cpu().numpy() + base_anchors_trt = trt_(ctx.network, base_anchors.float()) plugin = create_gridanchordynamic_plugin("ag_" + str(id(module)), - base_size=base_size, - stride=stride, - scales=scales, - ratios=ratios, - scale_major=scale_major, - center_x=center_x, - center_y=center_y) - - custom_layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin) + stride=stride) + else: + print("no base_anchors in {}".format(ag.generator)) + # scales = ag.scales.detach().cpu().numpy().astype(np.float32) + # ratios = ag.ratios.detach().cpu().numpy().astype(np.float32) + # scale_major = ag.scale_major + # ctr = ag.ctr + # if ctr is None: + # # center_x = -1 + # # center_y = -1 + # center_x = 0 + # center_y = 0 + # else: + # center_x, center_y = ag.ctr + + # plugin = create_gridanchordynamic_plugin("ag_" + str(id(module)), + # base_size=base_size, + # stride=stride, + # scales=scales, + # ratios=ratios, + # scale_major=scale_major, + # center_x=center_x, + # center_y=center_y) + + custom_layer = ctx.network.add_plugin_v2( + inputs=[input_trt, base_anchors_trt], plugin=plugin) output._trt = custom_layer.get_output(0) diff --git a/mmdet2trt/converters/plugins/create_gridanchordynamic_plugin.py b/mmdet2trt/converters/plugins/create_gridanchordynamic_plugin.py index 024abb1..35091c3 100644 --- a/mmdet2trt/converters/plugins/create_gridanchordynamic_plugin.py +++ b/mmdet2trt/converters/plugins/create_gridanchordynamic_plugin.py @@ -11,59 +11,15 @@ def create_gridanchordynamic_plugin(layer_name, - base_size, - stride, - scales=np.array([1.]), - ratios=np.array([1.]), - scale_major=True, - center_x=-1, - center_y=-1, - base_anchors=None): + stride): creator = trt.get_plugin_registry().get_plugin_creator( 'GridAnchorDynamicPluginDynamic', '1', '') pfc = trt.PluginFieldCollection() - pf_base_size = trt.PluginField("base_size", - np.array([base_size], dtype=np.int32), - trt.PluginFieldType.INT32) - pfc.append(pf_base_size) - pf_stride = trt.PluginField("stride", np.array([stride], dtype=np.int32), trt.PluginFieldType.INT32) pfc.append(pf_stride) - pf_scales = trt.PluginField("scales", - np.array(scales).astype(np.float32), - trt.PluginFieldType.FLOAT32) - pfc.append(pf_scales) - - pf_ratios = trt.PluginField("ratios", - np.array(ratios).astype(np.float32), - trt.PluginFieldType.FLOAT32) - pfc.append(pf_ratios) - - pf_scale_major = trt.PluginField( - "scale_major", np.array([int(scale_major)], dtype=np.int32), - trt.PluginFieldType.INT32) - pfc.append(pf_scale_major) - - pf_center_x = trt.PluginField("center_x", - np.array([center_x], dtype=np.int32), - trt.PluginFieldType.INT32) - pfc.append(pf_center_x) - - pf_center_y = trt.PluginField("center_y", - np.array([center_y], dtype=np.int32), - trt.PluginFieldType.INT32) - pfc.append(pf_center_y) - - if base_anchors is not None: - pf_base_anchors = trt.PluginField( - "base_anchors", - np.array(base_anchors).astype(np.float32), - trt.PluginFieldType.FLOAT32) - pfc.append(pf_base_anchors) - return creator.create_plugin(layer_name, pfc)