diff --git a/onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h b/onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h index f3e7b6d8b..94d780533 100644 --- a/onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h +++ b/onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h @@ -93,7 +93,7 @@ struct FuseMatMulAddBiasIntoGemm final : public PredicateBasedPass { gemm->f_(kbeta, 1.0); gemm->i_(ktransA, 0); gemm->i_(ktransB, 0); - gemm->insertBefore(orig_matmul->node()); + gemm->insertBefore(n); const bool replacing_success = tryReplacingAllUsesWith(n, gemm); if (!replacing_success) { return false; diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index 5cd6b32fd..ca969d4d2 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -1478,9 +1478,10 @@ def test_fuse_matmul_add_bias_into_gemm(self): # type: () -> None [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))], ) optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"]) + assert len(list(optimized_model.graph.node)) == 2 + assert optimized_model.graph.node[0].op_type == "MatMul" + assert optimized_model.graph.node[1].op_type == "Gemm" - assert len(list(optimized_model.graph.node)) == 1 - assert optimized_model.graph.node[0].op_type == "Gemm" def test_fuse_matmul_add_bias_into_gemm_2d_bias(self): # type: () -> None matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"]) @@ -1497,8 +1498,9 @@ def test_fuse_matmul_add_bias_into_gemm_2d_bias(self): # type: () -> None ) optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"]) - assert len(list(optimized_model.graph.node)) == 1 - assert optimized_model.graph.node[0].op_type == "Gemm" + assert len(list(optimized_model.graph.node)) == 2 + assert optimized_model.graph.node[0].op_type == "MatMul" + assert optimized_model.graph.node[1].op_type == "Gemm" # type: () -> None def test_fuse_matmul_add_bias_into_gemm_2d_bias_same_shape(self): @@ -1516,8 +1518,34 @@ def test_fuse_matmul_add_bias_into_gemm_2d_bias_same_shape(self): ) optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"]) - assert len(list(optimized_model.graph.node)) == 1 - assert optimized_model.graph.node[0].op_type == "Gemm" + assert len(list(optimized_model.graph.node)) == 2 + assert optimized_model.graph.node[0].op_type == "MatMul" + assert optimized_model.graph.node[1].op_type == "Gemm" + + def test_fuse_matmul_add_bias_into_gemm_must_keep_topological_order(self): + matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"], name='matmul') + bias = helper.make_node("MatMul", ["M", "N"], ["B"], name='bias') + add = helper.make_node("Add", ["Z", "B"], ["A"], name='add') + graph = helper.make_graph( + [matmul, bias, add], + "test", + [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)), + helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)), + helper.make_tensor_value_info("M", TensorProto.FLOAT, (32, 10)), + helper.make_tensor_value_info("N", TensorProto.FLOAT, (10, 16)), + ], + [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))], + value_info=[ + helper.make_tensor_value_info("B", TensorProto.FLOAT, (32, 16)), + ] + ) + optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"]) + + assert len(list(optimized_model.graph.node)) == 3 + assert optimized_model.graph.node[0].op_type == "MatMul" + assert optimized_model.graph.node[1].op_type == "MatMul" + assert optimized_model.graph.node[2].op_type == "Gemm" # type: () -> None def test_fuse_matmul_add_bias_into_gemm_2d_bias_bcast_no_fuse(self):