Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Jul 10, 2024
1 parent 7488ea1 commit 28a72e9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 36 deletions.
36 changes: 9 additions & 27 deletions tests/python/relax/test_dataflow_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,8 @@ def before(x: R.Tensor([16], "float32")):

@R.function(private=True)
def expected(x: R.Tensor([16], "float32")):
y = R.call_pure_packed(
"my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")
)
z = R.call_pure_packed(
"my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32")
)
y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32"))
z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32"))
return z

after = Rewriter(before)
Expand Down Expand Up @@ -316,12 +312,8 @@ def expected(
B: R.Tensor([16], "float32"),
C: R.Tensor([16], "float32"),
):
D = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
E = R.call_pure_packed(
"my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32")
)
D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32"))
E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32"))
return E

rewriter = RewriteAdd | RewriteMultiply
Expand Down Expand Up @@ -457,9 +449,7 @@ def pattern(A: R.Tensor([16], "float32")):

@R.function
def replacement(A: R.Tensor([16], "float32")):
return R.call_tir(
RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32")
)
return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32"))

@T.prim_func(private=True)
def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
Expand Down Expand Up @@ -537,9 +527,7 @@ def pattern(A: R.Tensor([16], "float32")):

@R.function
def replacement(A: R.Tensor([16], "float32")):
return R.call_tir(
RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32")
)
return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32"))

@T.prim_func(private=True)
def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
Expand All @@ -559,9 +547,7 @@ class Expected:
@R.function
def main(A: R.Tensor([16], "float32")):
B = Expected.subroutine(A)
C = R.call_tir(
Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32")
)
C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32"))
return C

@R.function(private=True)
Expand Down Expand Up @@ -1212,9 +1198,7 @@ def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
)

@R.function(private=True)
def before(
A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
):
def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")):
if cond:
out = A + B
else:
Expand All @@ -1223,9 +1207,7 @@ def before(
return out

@R.function(private=True)
def expected(
A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
):
def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")):
if cond:
out = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
Expand Down
12 changes: 3 additions & 9 deletions tests/python/relax/test_transform_canonicalize_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,7 @@ def main(

# The symbolic shapes propagate downstream.
lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0)
proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(
lhs, rhs, out_dtype="void"
)
proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void")
proj_A = R.strided_slice(
proj_concat,
(R.prim_value(0),),
Expand Down Expand Up @@ -384,9 +382,7 @@ def main(
# statically-known shapes.

lhs: R.Tensor([32, 16], dtype="float32") = R.concat((A, B), axis=0)
proj_concat: R.Tensor([32], dtype="float32") = R.matmul(
lhs, state, out_dtype="void"
)
proj_concat: R.Tensor([32], dtype="float32") = R.matmul(lhs, state, out_dtype="void")
proj_A: R.Tensor([16], dtype="float32") = R.strided_slice(
proj_concat,
[R.prim_value(0)],
Expand Down Expand Up @@ -435,9 +431,7 @@ def main(
rhs = R.match_cast(state, R.Tensor([M], dtype="float32"))

lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0)
proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(
lhs, rhs, out_dtype="void"
)
proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void")
proj_A = R.strided_slice(
proj_concat,
(R.prim_value(0),),
Expand Down

0 comments on commit 28a72e9

Please sign in to comment.