diff --git a/onnxoptimizer/passes/fuse_bn_into_conv.h b/onnxoptimizer/passes/fuse_bn_into_conv.h
index d1d8f4a09..cd4f7ef12 100644
--- a/onnxoptimizer/passes/fuse_bn_into_conv.h
+++ b/onnxoptimizer/passes/fuse_bn_into_conv.h
@@ -69,7 +69,42 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
     }
   }
 
-  bool modify_conv(Node* conv, Node* bn, Graph& graph) {
+  void scale_by_dim(Tensor& W, Tensor& s, const int axis) {
+    ONNX_ASSERT(W.sizes().size() > 1 && s.sizes().size() == 1 && s.sizes()[0] == W.sizes()[axis]);
+    ONNX_ASSERT(s.elem_type() == W.elem_type());
+    const int64_t inner_size = W.size_from_dim(axis + 1);
+    const int64_t outer_size = axis > 0 ? std::accumulate(W.sizes().begin(), W.sizes().begin() + axis, static_cast<int64_t>(1), std::multiplies<int64_t>()) : 1;
+    const int64_t axis_size = W.sizes()[axis];
+
+#define DO_SCALE(TENSOR_TYPE)                        \
+  TENSOR_TYPE* ptr = W.data<TENSOR_TYPE>();          \
+  const TENSOR_TYPE* s_ptr = s.data<TENSOR_TYPE>();  \
+  int64_t counter = 0;                               \
+  for (int64_t i = 0; i < outer_size; ++i) {         \
+    for (int64_t j = 0; j < axis_size; ++j) {        \
+      for (int64_t k = 0; k < inner_size; ++k) {     \
+        ptr[counter++] *= s_ptr[j];                  \
+      }                                              \
+    }                                                \
+  }
+
+    switch (s.elem_type()) {
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
+        DO_SCALE(float)
+        break;
+      }
+      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
+        DO_SCALE(double)
+        break;
+      }
+      default:
+        TENSOR_ASSERTM(
+            false, "Operation scale_by_dim not supported for data type %s", to_string(W.elem_type()).c_str());
+    }
+#undef DO_SCALE
+  }
+
+  bool modify_conv(Node* conv, Node* bn, Graph& graph, const bool is_conv) {
     const auto& bn_inputs = bn->inputs();
     const auto& conv_inputs = conv->inputs();
     auto end_iter = graph.initializers().end();
@@ -136,7 +171,6 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
   var.add(eps);              \
   var.sqrt();                \
   s.divide(var);             \
-  W.scale_by_first_dim(s);   \
   bc.subtract(m);            \
   bc.multiply(s);            \
   bc.add(bbn);
@@ -154,21 +188,38 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
       return false;
     }
 #undef DO_COMPUTATION
+    if (is_conv) {
+      scale_by_dim(W, s, 0);
+    } else {
+      scale_by_dim(W, s, 1);
+    }
     replace_inputs(W, bc, conv, graph);
     return true;
   }
 
-  bool patternMatchPredicate(Node* node) override {
+  inline bool matchConvBn(Node *node) {
     return node->kind() == kBatchNormalization &&
           node->inputs()[0]->node()->kind() == kConv;
   }
+
+  inline bool matchConvTransposeBn(Node *node) {
+    return node->kind() == kBatchNormalization &&
+          node->inputs()[0]->node()->kind() == kConvTranspose;
+  }
+
+  bool patternMatchPredicate(Node *node) override {
+    return matchConvBn(node) || matchConvTransposeBn(node);
+  }
+
   bool runTransform(Node* n, Graph& graph,
                     NodeDestroyType& destroy_current) override {
+    const bool is_conv = matchConvBn(n);
+
     Node* bn = n;
     Node* conv = n->inputs()[0]->node();
     auto origInput = bn->inputs()[0];
     if (origInput->uses().size() > 1 || bn->outputs().size() > 1 ||
-        !modify_conv(conv, bn, graph)) {
+        !modify_conv(conv, bn, graph, is_conv)) {
       destroy_current = NodeDestroyType::DestroyZero;
       return false;
     }
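Note (illustrative, not part of the patch): the axis passed to `scale_by_dim` differs because ONNX stores Conv weights as (M, C/group, kH, kW) but ConvTranspose weights as (C, M/group, kH, kW); the folded BatchNorm scale s = scale / sqrt(var + eps) must multiply the output-channel axis, which is axis 0 for Conv and axis 1 for ConvTranspose. A minimal standalone sketch of that per-axis scaling follows; the names and shapes are hypothetical and this does not use the pass's Tensor API.

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Multiply every element whose index along `axis` is j by s[j].
// Mirrors the iteration pattern of scale_by_dim in the patch.
void scale_by_dim(std::vector<float>& w, const std::vector<int64_t>& dims,
                  const std::vector<float>& s, int axis) {
  const int64_t inner = std::accumulate(dims.begin() + axis + 1, dims.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
  const int64_t outer = std::accumulate(dims.begin(), dims.begin() + axis,
                                        int64_t{1}, std::multiplies<int64_t>());
  int64_t idx = 0;
  for (int64_t i = 0; i < outer; ++i)
    for (int64_t j = 0; j < dims[axis]; ++j)
      for (int64_t k = 0; k < inner; ++k)
        w[idx++] *= s[j];
}

int main() {
  // Two output channels, one input channel, 1x1 kernel.
  std::vector<float> conv_w = {1.f, 2.f};   // Conv layout (M=2, C=1, 1, 1)
  std::vector<float> convT_w = {1.f, 2.f};  // ConvTranspose layout (C=1, M=2, 1, 1)
  std::vector<float> s = {10.f, 100.f};     // per-output-channel BN scale
  scale_by_dim(conv_w, {2, 1, 1, 1}, s, 0);   // Conv: scale along axis 0
  scale_by_dim(convT_w, {1, 2, 1, 1}, s, 1);  // ConvTranspose: scale along axis 1
  // Both end up as {10, 200}: each output channel scaled by its own s[j].
  return 0;
}
```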