Skip to content

Commit

Permalink
optimized graph needs to preserve the topological order
Browse files Browse the repository at this point in the history
Signed-off-by: lmcl90 <[email protected]>
  • Loading branch information
lmcl90 committed Nov 12, 2024
1 parent b3a4611 commit 70aa4e6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
2 changes: 1 addition & 1 deletion onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct FuseMatMulAddBiasIntoGemm final : public PredicateBasedPass {
gemm->f_(kbeta, 1.0);
gemm->i_(ktransA, 0);
gemm->i_(ktransB, 0);
gemm->insertBefore(orig_matmul->node());
gemm->insertBefore(n);
const bool replacing_success = tryReplacingAllUsesWith(n, gemm);
if (!replacing_success) {
return false;
Expand Down
40 changes: 34 additions & 6 deletions onnxoptimizer/test/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,9 +1478,10 @@ def test_fuse_matmul_add_bias_into_gemm(self): # type: () -> None
[helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))],
)
optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
assert len(list(optimized_model.graph.node)) == 2
assert optimized_model.graph.node[0].op_type == "MatMul"
assert optimized_model.graph.node[1].op_type == "Gemm"

assert len(list(optimized_model.graph.node)) == 1
assert optimized_model.graph.node[0].op_type == "Gemm"

def test_fuse_matmul_add_bias_into_gemm_2d_bias(self): # type: () -> None
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
Expand All @@ -1497,8 +1498,9 @@ def test_fuse_matmul_add_bias_into_gemm_2d_bias(self): # type: () -> None
)
optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])

assert len(list(optimized_model.graph.node)) == 1
assert optimized_model.graph.node[0].op_type == "Gemm"
assert len(list(optimized_model.graph.node)) == 2
assert optimized_model.graph.node[0].op_type == "MatMul"
assert optimized_model.graph.node[1].op_type == "Gemm"

# type: () -> None
def test_fuse_matmul_add_bias_into_gemm_2d_bias_same_shape(self):
Expand All @@ -1516,8 +1518,34 @@ def test_fuse_matmul_add_bias_into_gemm_2d_bias_same_shape(self):
)
optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])

assert len(list(optimized_model.graph.node)) == 1
assert optimized_model.graph.node[0].op_type == "Gemm"
assert len(list(optimized_model.graph.node)) == 2
assert optimized_model.graph.node[0].op_type == "MatMul"
assert optimized_model.graph.node[1].op_type == "Gemm"

def test_fuse_matmul_add_bias_into_gemm_must_keep_topological_order(self):
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"], name='matmul')
bias = helper.make_node("MatMul", ["M", "N"], ["B"], name='bias')
add = helper.make_node("Add", ["Z", "B"], ["A"], name='add')
graph = helper.make_graph(
[matmul, bias, add],
"test",
[
helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
helper.make_tensor_value_info("M", TensorProto.FLOAT, (32, 10)),
helper.make_tensor_value_info("N", TensorProto.FLOAT, (10, 16)),
],
[helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))],
value_info=[
helper.make_tensor_value_info("B", TensorProto.FLOAT, (32, 16)),
]
)
optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])

assert len(list(optimized_model.graph.node)) == 3
assert optimized_model.graph.node[0].op_type == "MatMul"
assert optimized_model.graph.node[1].op_type == "MatMul"
assert optimized_model.graph.node[2].op_type == "Gemm"

# type: () -> None
def test_fuse_matmul_add_bias_into_gemm_2d_bias_bcast_no_fuse(self):
Expand Down

0 comments on commit 70aa4e6

Please sign in to comment.