From ff1229e4ae3a61e65ef7faf416b619cb91739005 Mon Sep 17 00:00:00 2001 From: "wenyuchi.wyc" Date: Tue, 22 Nov 2022 11:37:22 +0800 Subject: [PATCH] Support fuse bn into ConvTranspose. Signed-off-by: wenyuchi.wyc --- onnxoptimizer/passes/fuse_bn_into_conv.h | 16 +++++----- onnxoptimizer/test/optimizer_test.py | 40 ++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/onnxoptimizer/passes/fuse_bn_into_conv.h b/onnxoptimizer/passes/fuse_bn_into_conv.h index 0ff711d4d..b0353ea99 100644 --- a/onnxoptimizer/passes/fuse_bn_into_conv.h +++ b/onnxoptimizer/passes/fuse_bn_into_conv.h @@ -47,7 +47,7 @@ struct FuseBNIntoConv final : public PredicateBasedPass { return "fuse_bn_into_conv"; } - bool modify_conv(Node* conv, Node* bn, Graph& graph) { + bool modify_conv(Node* conv, Node* bn, Graph& graph, const bool is_conv) { const auto& bn_inputs = bn->inputs(); const auto& conv_inputs = conv->inputs(); @@ -123,10 +123,9 @@ struct FuseBNIntoConv final : public PredicateBasedPass { Node* unsqueeze = graph.create(kUnsqueeze, 1); unsqueeze->insertAfter(scale); unsqueeze->addInput(scale->output()); - std::vector insert_dims; - for (int i = 1; i < conv_W.sizes().size(); ++i) { - insert_dims.push_back(i); - } + std::vector insert_dims(conv_W.sizes().size()); + std::iota(insert_dims.begin(), insert_dims.end(), 0); + insert_dims.erase(insert_dims.begin() + (is_conv ? 0 : 1)); if (getOpsetVersion(graph) > 11) { Tensor shape_s_t; shape_s_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64; @@ -181,7 +180,8 @@ struct FuseBNIntoConv final : public PredicateBasedPass { } bool patternMatchPredicate(Node* n) override { - return CheckKind(n, kBatchNormalization, 0, kConv) && + return (CheckKind(n, kBatchNormalization, 0, kConv) || + CheckKind(n, kBatchNormalization, 0, kConvTranspose)) && GetValueFromAttrWithDefault(n, "training_mode", (int64_t)0) == 0 && n->input(0)->uses().size() == 1 && n->outputs().size() == 1 && IsConstantTensor(n, 1) && IsConstantTensor(n, 2) && @@ -190,10 +190,12 @@ struct FuseBNIntoConv final : public PredicateBasedPass { } bool runTransform(Node* n, Graph& graph, NodeDestroyType& destroy_current) override { + const bool is_conv = CheckKind(n, kBatchNormalization, 0, kConv); + Node* bn = n; Node* conv = PrevNode(n, 0); auto origInput = bn->inputs()[0]; - if (!modify_conv(conv, bn, graph)) { + if (!modify_conv(conv, bn, graph, is_conv)) { destroy_current = NodeDestroyType::DestroyZero; return false; } diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index ad7721800..2c180154c 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -3063,6 +3063,46 @@ def test_fuse_bn_into_conv_simple(self): # type: () -> None ) optimized_model = self._optimized(graph, ["fuse_bn_into_conv"]) # noqa + def test_fuse_bn_into_conv_transpose_simple(self): # type: () -> None + for (tensor_type, np_type) in [(TensorProto.FLOAT, np.float32)]: + conv = helper.make_node("ConvTranspose", ["X", "W", "B"], ["Y"], strides=(2, 2)) + bn = helper.make_node( + "BatchNormalization", ["Y", "scale", "b", "mean", "var"], ["Z"] + ) + + W = np.random.randn(64, 64, 2, 2).astype(np_type) + 2 + B = np.random.randn(64,).astype(np_type) + 2 + scale = np.random.randn(64,).astype(np_type) + 2 + b = np.random.randn(64,).astype(np_type) + 2 + mean = np.random.randn(64,).astype(np_type) + 2 + var = np.abs(np.random.randn(64,).astype(np_type)) + 2 + + initializers = [ + helper.make_tensor( + name, tensor_type, npa.shape, npa.tobytes(), raw=True + ) + for name, npa in [ + ("W", W), + ("B", B), + ("scale", scale), + ("b", b), + ("mean", mean), + ("var", var), + ] + ] + graph = helper.make_graph( + [conv, bn], + "test", + [helper.make_tensor_value_info("X", tensor_type, (1, 64, 160, 160))], + [helper.make_tensor_value_info("Z", tensor_type, (1, 64, 320, 320))], + initializer=initializers, + value_info=[ + helper.make_tensor_value_info("Y", tensor_type, (1, 64, 320, 320)) + ], + ) + + optimized_model = self._optimized(graph, ["fuse_bn_into_conv"]) + def _internal_test_deadend_elimination(self, fixed): # type: (bool) -> None softmax = helper.make_node("Softmax", ["X"], ["Y"], axis=2) log = helper.make_node("Log", ["Y"], ["Z"])