Skip to content

Commit

Permalink
Adding test in test_tvmscript_printer_tir.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rutkoor committed Jun 3, 2024
1 parent 9027b21 commit 39d2838
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/python/tvmscript/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,5 +1045,26 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
_assert_print(main, expected_output)


def test_vectorize_llvm_pure_intrin():
from tvm.script import tir as T

@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (4,), "float32")
B = T.match_buffer(b, (4,), "float32")
A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin(
"float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)]
)

expected_output = """
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4])
"""
_assert_print(main, expected_output)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 39d2838

Please sign in to comment.