diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py index 1d917c59523b..05bbe429bbcc 100644 --- a/tests/python/relax/test_dataflow_rewriter.py +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -257,12 +257,8 @@ def before(x: R.Tensor([16], "float32")): @R.function(private=True) def expected(x: R.Tensor([16], "float32")): - y = R.call_pure_packed( - "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") - ) - z = R.call_pure_packed( - "my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32") - ) + y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")) + z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32")) return z after = Rewriter(before) @@ -316,12 +312,8 @@ def expected( B: R.Tensor([16], "float32"), C: R.Tensor([16], "float32"), ): - D = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) - E = R.call_pure_packed( - "my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32") - ) + D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")) + E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32")) return E rewriter = RewriteAdd | RewriteMultiply @@ -457,9 +449,7 @@ def pattern(A: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32")): - return R.call_tir( - RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32") - ) + return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32")) @T.prim_func(private=True) def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): @@ -537,9 +527,7 @@ def pattern(A: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32")): - return R.call_tir( - RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32") - ) + return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32")) @T.prim_func(private=True) def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): @@ -559,9 +547,7 @@ class Expected: @R.function def main(A: R.Tensor([16], "float32")): B = Expected.subroutine(A) - C = R.call_tir( - Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32") - ) + C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32")) return C @R.function(private=True) @@ -1212,9 +1198,7 @@ def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): ) @R.function(private=True) - def before( - A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool") - ): + def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): if cond: out = A + B else: @@ -1223,9 +1207,7 @@ def before( return out @R.function(private=True) - def expected( - A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool") - ): + def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): if cond: out = R.call_pure_packed( "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 3255889960d4..ea3b1c249b8b 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -352,9 +352,7 @@ def main( # The symbolic shapes propagate downstream. lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) - proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul( - lhs, rhs, out_dtype="void" - ) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void") proj_A = R.strided_slice( proj_concat, (R.prim_value(0),), @@ -384,9 +382,7 @@ def main( # statically-known shapes. lhs: R.Tensor([32, 16], dtype="float32") = R.concat((A, B), axis=0) - proj_concat: R.Tensor([32], dtype="float32") = R.matmul( - lhs, state, out_dtype="void" - ) + proj_concat: R.Tensor([32], dtype="float32") = R.matmul(lhs, state, out_dtype="void") proj_A: R.Tensor([16], dtype="float32") = R.strided_slice( proj_concat, [R.prim_value(0)], @@ -435,9 +431,7 @@ def main( rhs = R.match_cast(state, R.Tensor([M], dtype="float32")) lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) - proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul( - lhs, rhs, out_dtype="void" - ) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void") proj_A = R.strided_slice( proj_concat, (R.prim_value(0),),