From ded269d9196ea45076034a15f2d4d13f6c6c6642 Mon Sep 17 00:00:00 2001 From: "wenyuchi.wyc" Date: Tue, 22 Nov 2022 11:33:45 +0800 Subject: [PATCH] Support fuse add into ConvTranspose. --- onnxoptimizer/passes/fuse_add_bias_into_conv.h | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/onnxoptimizer/passes/fuse_add_bias_into_conv.h b/onnxoptimizer/passes/fuse_add_bias_into_conv.h index 7b8726b0f..c3d8c6e19 100644 --- a/onnxoptimizer/passes/fuse_add_bias_into_conv.h +++ b/onnxoptimizer/passes/fuse_add_bias_into_conv.h @@ -32,10 +32,21 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass { std::string getPassName() const override { return "fuse_add_bias_into_conv"; } - bool patternMatchPredicate(Node *node) override { + + inline bool matchConvAdd(Node *node) { return node->kind() == kAdd && node->inputs()[0]->node()->kind() == kConv && node->inputs()[0]->node()->inputs().size() == 2; } + + inline bool matchConvTransposeAdd(Node *node) { + return node->kind() == kAdd && node->inputs()[0]->node()->kind() == kConvTranspose && + node->inputs()[0]->node()->inputs().size() == 2; + } + + bool patternMatchPredicate(Node *node) override { + return matchConvAdd(node) || matchConvTransposeAdd(node); + } + static Node *makeSqueezeOrUnsqueeze(Graph &graph, std::vector &axes, Value *input, Node *target_node, BuiltinSymbol k) { @@ -61,6 +72,7 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass { NodeDestroyType &destroy_current) override { // due to current broadcasting's constraint, Conv has to be the first // operand + const bool is_conv = matchConvAdd(n); destroy_current = NodeDestroyType::DestroyZero; auto orig_conv = n->inputs()[0]; auto orig_bias = n->inputs()[1]; @@ -85,8 +97,8 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass { } // try to get feature M and rank from weight_shape if (weight_shape.size() > 0 && weight_shape[0].is_int) { - ONNX_ASSERT(M == -1 || M == weight_shape[0].dim); - M = weight_shape[0].dim; + ONNX_ASSERT(M == -1 || M == weight_shape[0].dim || M == weight_shape[1].dim); + M = is_conv ? weight_shape[0].dim : weight_shape[1].dim; ONNX_ASSERT(rank == -1 || rank == static_cast(weight_shape.size())); rank = weight_shape.size();