From bd3d9a97368eaf421ee0aa1a483690ae5dffc5dc Mon Sep 17 00:00:00 2001 From: Lingxiao Ma Date: Thu, 20 May 2021 16:55:16 +0800 Subject: [PATCH] Fix DepthWiseConv2dNative (#267) * fix shape inference of generic_op_define/DepthwiseConv2dNative * onnx frontend support for DepthwiseConv2dNative op * fix bug in TensorFlow frontend for DepthwiseConv2dNative op * fix bug in onnx frontend for DepthwiseConv2dNative Co-authored-by: Lingxiao Ma --- .../DepthwiseConv2dNative.cpp | 23 ++-- src/nnfusion/frontend/onnx_import/op/conv.cpp | 125 +++++++++++++----- .../tensorflow_import/util/graph_convert.cpp | 2 +- 3 files changed, 104 insertions(+), 46 deletions(-) diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/DepthwiseConv2dNative.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/DepthwiseConv2dNative.cpp index 1d88548d1..7ebd95413 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/DepthwiseConv2dNative.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/DepthwiseConv2dNative.cpp @@ -10,13 +10,6 @@ REGISTER_OP(DepthwiseConv2dNative) .attr("dilations") .attr("padding_before") .attr("padding_after") - .constrait([](const nnfusion::op::OpConfig::any& config) -> bool { - if (config["padding_type"] != "SAME") - { - return false; - } - return true; - }) .infershape([](std::shared_ptr gnode) -> void { NNFUSION_CHECK(gnode->get_input_size() == 2); auto op = std::dynamic_pointer_cast(gnode->get_op_ptr()); @@ -39,11 +32,23 @@ REGISTER_OP(DepthwiseConv2dNative) const int64_t filter_rows = filter_shape[0]; const int64_t filter_cols = filter_shape[1]; const int64_t batch = input_shape[0]; + auto padding_before = op->localOpConfig.getRoot()["padding_before"]; + auto padding_after = op->localOpConfig.getRoot()["padding_after"]; + const int64_t padding_h = padding_before[0]; + const int64_t padding_w = padding_before[1]; + const int64_t dilation_h = op->localOpConfig.getRoot()["dilations"][0]; + const int64_t dilation_w = 
op->localOpConfig.getRoot()["dilations"][1]; std::vector strides = op->localOpConfig.getRoot()["strides"]; NNFUSION_CHECK(strides.size() == 2); - const int64_t out_rows = (input_rows + strides[0] - 1) / strides[0]; - const int64_t out_cols = (input_cols + strides[1] - 1) / strides[1]; + const int64_t out_rows = + (int64_t)((float)(input_rows + 2 * padding_h - dilation_h * (filter_rows - 1) - 1) / + (float)strides[0] + + 1); + const int64_t out_cols = + (int64_t)((float)(input_cols + 2 * padding_w - dilation_w * (filter_cols - 1) - 1) / + (float)strides[1] + + 1); Shape output_shape( is_nhwc diff --git a/src/nnfusion/frontend/onnx_import/op/conv.cpp b/src/nnfusion/frontend/onnx_import/op/conv.cpp index 21ea25223..7d089e19c 100644 --- a/src/nnfusion/frontend/onnx_import/op/conv.cpp +++ b/src/nnfusion/frontend/onnx_import/op/conv.cpp @@ -22,6 +22,7 @@ #include "conv.hpp" #include +#include "nnfusion/core/operators/generic_op/generic_op.hpp" #include "nnfusion/frontend/onnx_import/util/broadcasting.hpp" namespace nnfusion @@ -143,44 +144,96 @@ namespace nnfusion // split data and filters for group conv std::size_t n_data_channels{data_shape.at(1)}; std::size_t n_filters_channels{filters_shape.at(0)}; - NNFUSION_CHECK(n_data_channels % groups == 0 && - n_filters_channels & groups == 0); - std::size_t data_group_size{n_data_channels / groups}; - std::size_t filters_group_size{n_filters_channels / groups}; - - std::vector data_lower_bounds(data_shape.size(), 0); - std::vector data_upper_bounds{data_shape}; - std::vector filters_lower_bounds(filters_shape.size(), 0); - std::vector filters_upper_bounds{filters_shape}; - - std::vector> convolution_nodes; - for (std::size_t group = 0; group < groups; ++group) + if (n_data_channels == groups) { - // slice data - data_lower_bounds[1] = group * data_group_size; - data_upper_bounds[1] = (group + 1) * data_group_size; - auto sliced_data_op = - std::make_shared(data_lower_bounds, data_upper_bounds); - auto sliced_data = 
m_graph->add_node_and_edge(sliced_data_op, {data}); - // slice filters - filters_lower_bounds[0] = group * filters_group_size; - filters_upper_bounds[0] = (group + 1) * filters_group_size; - auto sliced_filters_op = std::make_shared( - filters_lower_bounds, filters_upper_bounds); - auto sliced_filters = - m_graph->add_node_and_edge(sliced_filters_op, {filters}); - - convolution_nodes.push_back(m_graph->add_node_and_edge( - std::make_shared(strides, - dilations, - padding_below, - padding_above, - conv_data_format), - {sliced_data, sliced_filters})); + // depthwise convolution + NNFUSION_CHECK(n_filters_channels == groups) + << "Currently only support depth_multiplier = 1 in DepthwiseConv2d"; + + auto filter_gnode = GetInputNode(all_ng_nodes, node_proto, 1); + auto reshape_filter_op = std::make_shared( + nnfusion::AxisVector{2, 3, 0, 1}, + nnfusion::Shape({filters_shape[2], + filters_shape[3], + filters_shape[0], + filters_shape[1]})); + reshape_filter_op->set_name(filter_gnode->get_name() + + "_filters_reshape"); + auto reshape_filter_gnode = + m_graph->add_node_and_edge(reshape_filter_op, {filter_gnode}); + + size_t depth_multiplier = 1; + nnfusion::op::OpConfig::any myConfig; + myConfig["data_format"] = "NCHW"; + myConfig["strides"] = strides; + myConfig["dilations"] = dilations; + myConfig["padding_before"] = padding_below; + myConfig["padding_after"] = padding_above; + + if ((2 * padding_below[0] - dilations[0] * (filters_shape[2] - 1) == + 0) && + (2 * padding_below[1] - dilations[1] * (filters_shape[3] - 1) == 0)) + { + myConfig["padding_type"] = "SAME"; + } + else if (padding_below[0] == 0 && padding_below[1] == 0) + { + myConfig["padding_type"] = "VALID"; + } + else + { + NNFUSION_CHECK_FAIL() << "Currently only support SAME and VALID " + "padding in DepthwiseConv2dNative"; + } + + auto conv_op = std::make_shared( + node_proto.name(), "DepthwiseConv2dNative", myConfig); + conv_node = m_graph->add_node_and_edge( + conv_op, {data, 
GNodeIndex{reshape_filter_gnode, 0}}); + } + else + { + NNFUSION_CHECK(n_data_channels % groups == 0 && + n_filters_channels % groups == 0); + std::size_t data_group_size{n_data_channels / groups}; + std::size_t filters_group_size{n_filters_channels / groups}; + + std::vector data_lower_bounds(data_shape.size(), 0); + std::vector data_upper_bounds{data_shape}; + std::vector filters_lower_bounds(filters_shape.size(), 0); + std::vector filters_upper_bounds{filters_shape}; + + std::vector> convolution_nodes; + for (std::size_t group = 0; group < groups; ++group) + { + // slice data + data_lower_bounds[1] = group * data_group_size; + data_upper_bounds[1] = (group + 1) * data_group_size; + auto sliced_data_op = std::make_shared( + data_lower_bounds, data_upper_bounds); + auto sliced_data = + m_graph->add_node_and_edge(sliced_data_op, {data}); + // slice filters + filters_lower_bounds[0] = group * filters_group_size; + filters_upper_bounds[0] = (group + 1) * filters_group_size; + auto sliced_filters_op = std::make_shared( + filters_lower_bounds, filters_upper_bounds); + auto sliced_filters = + m_graph->add_node_and_edge(sliced_filters_op, {filters}); + + convolution_nodes.push_back(m_graph->add_node_and_edge( + std::make_shared(strides, + dilations, + padding_below, + padding_above, + conv_data_format), + {sliced_data, sliced_filters})); + } + std::size_t concatenation_axis = 1; + conv_node = m_graph->add_node_and_edge( + std::make_shared(concatenation_axis), + convolution_nodes); } - std::size_t concatenation_axis = 1; - conv_node = m_graph->add_node_and_edge( - std::make_shared(concatenation_axis), convolution_nodes); } // add bias diff --git a/src/nnfusion/frontend/tensorflow_import/util/graph_convert.cpp b/src/nnfusion/frontend/tensorflow_import/util/graph_convert.cpp index 7d8d6f64a..5c1ca4fb4 100644 --- a/src/nnfusion/frontend/tensorflow_import/util/graph_convert.cpp +++ b/src/nnfusion/frontend/tensorflow_import/util/graph_convert.cpp @@ -869,7 +869,7 @@ namespace 
nnfusion ng_padding_below, ng_padding_above); - NNFUSION_CHECK(ng_padding_above == ng_padding_above) + NNFUSION_CHECK(ng_padding_below == ng_padding_above) << "Asymetric padding is not supported by now."; nnfusion::op::OpConfig::any op_config; op_config["data_format"] = tf_data_format;