Skip to content

Commit

Permalink
[CODEGEN] Vector-Codegen support for llvm-pure-intrin (#16985)
Browse files Browse the repository at this point in the history
* Vector-Codegen support for llvm-pure-intrin
  • Loading branch information
rutkoor authored Jun 4, 2024
1 parent f5d3fc2 commit 78a1f80
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
// Registers the call_llvm_pure_intrin builtin with three attributes: TCallEffectKind
// (CallEffectKind::kPure), TScriptDtypePrintLocation (ScriptDtypePrintLocation::kFirst)
// and TVectorizable (true).
// NOTE(review): the two consecutive `Integer(ScriptDtypePrintLocation::kFirst)` lines look
// like the removed and the added side of this diff hunk (2 additions, 1 deletion); only the
// second one, which has no trailing `;`, chains into the TVectorizable attribute.
TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst));
Integer(ScriptDtypePrintLocation::kFirst))
.set_attr<TVectorizable>("TVectorizable", true);

// Registers the call_spirv_pure_glsl450 builtin and tags its call effect as
// CallEffectKind::kPure.
TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
Expand Down
23 changes: 22 additions & 1 deletion src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
} else {
int lane = 0;
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
Array<PrimExpr> new_args;
if (op->op.same_as(builtin::call_llvm_pure_intrin())) {
// op->args[1] holds the number of arguments passed to the intrinsic.
int num_signature = Downcast<IntImm>(op->args[1])->value;
Array<PrimExpr> op_expr_args;
for (int i = 0; i < num_signature; i++) {
// Collect all intrinsic arguments
op_expr_args.push_back(op->args[i + 2]);
}
// Generate RAMP nodes for intrinsic arguments
Array<PrimExpr> updated_args = MutateArray(op_expr_args, &lane);
// Keep the intrinsic ID and the argument count (op->args[0] and op->args[1]) unchanged.
for (int i = 0; i < 2; i++) {
new_args.push_back(op->args[i]);
}
// Collect updated intrinsic arguments
for (int i = 0; i < num_signature; i++) {
new_args.push_back(updated_args[i]);
}
} else {
new_args = MutateArray(op->args, &lane);
}
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
Expand Down
58 changes: 58 additions & 0 deletions tests/python/tir-transform/test_tir_transform_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,5 +790,63 @@ def expected(a: T.handle, b: T.handle):
tvm.ir.assert_structural_equal(after, expected)


@pytest.mark.parametrize(
    "extent, vec_str, target",
    [
        (4, "float32x4", simple_target),
    ],
)
def test_vectorize_llvm_pure_intrin(extent, vec_str, target):
    """Vectorizing a loop around a pure LLVM intrinsic call yields one vector call.

    ``Before`` calls ``llvm.sqrt`` per element inside a ``T.vectorized`` loop.
    ``After`` is the same store written with ``T.Ramp`` indices and ``vec_str``
    as the call dtype.  The output of ``VectorizeLoop`` must be structurally
    equal to ``After`` and must still build for ``target``.
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
            for j in T.vectorized(extent):
                A[j] = T.call_llvm_pure_intrin(
                    "float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j]
                )

    @I.ir_module
    class After:
        @T.prim_func
        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
            A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
                vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
            )

    with tvm.target.Target(target):
        vectorized = tvm.tir.transform.VectorizeLoop()(Before)
        tvm.ir.assert_structural_equal(vectorized, After)
        # Only the fact that the build succeeds matters; its result is not used.
        tvm.build(vectorized, target)


@pytest.mark.parametrize(
    "extent, vec_str, target",
    [(4, "int32x4", simple_target)],
)
def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, target):
    """Building a vectorized ``llvm.lround`` call is expected to fail.

    ``VectorizeLoop`` should rewrite ``Before`` into ``After`` (a ``vec_str``
    call over ``T.Ramp`` indices).  Building that module is then expected to
    raise, with an error message containing
    "Intrinsic does not support vectors".
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
            for j in T.vectorized(extent):
                A[j] = T.call_llvm_pure_intrin(
                    "int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j]
                )

    @I.ir_module
    class After:
        @T.prim_func
        def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
            A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
                vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
            )

    with tvm.target.Target(target):
        mod = tvm.tir.transform.VectorizeLoop()(Before)
        # Check the transform result before the build: anything placed after the
        # raising call inside ``pytest.raises`` would never execute.
        tvm.ir.assert_structural_equal(mod, After)
        # Keep the ``raises`` scope down to the one call that is expected to fail.
        with pytest.raises(Exception) as e_info:
            tvm.build(mod, target)
    # ``e_info.value`` is only populated once the ``raises`` block has exited.
    assert "Intrinsic does not support vectors" in e_info.value.args[0]


if __name__ == "__main__":
    # Executed directly rather than imported: delegate to tvm.testing.main().
    tvm.testing.main()
21 changes: 21 additions & 0 deletions tests/python/tvmscript/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,5 +1045,26 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
_assert_print(main, expected_output)


def test_vectorize_llvm_pure_intrin():
    """Check the printed script of a vector ``call_llvm_pure_intrin`` store.

    The function stores a ``float32x4`` ``llvm.sqrt`` call over
    ``B[T.Ramp(0, 1, 4)]`` into ``A[T.Ramp(0, 1, 4)]``.  The expected script
    shows both ramps as ``0:4`` slices and the call dtype as the first
    argument.
    """
    from tvm.script import tir as T

    @T.prim_func
    def main(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (4,), "float32")
        B = T.match_buffer(b, (4,), "float32")
        A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin(
            "float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)]
        )

    # The script text starts at column 0 on purpose: it is compared against the
    # printer output, so only the function body carries an indent.
    expected_output = """
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
    A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4])
"""
    _assert_print(main, expected_output)


if __name__ == "__main__":
    # Executed directly rather than imported: delegate to tvm.testing.main().
    tvm.testing.main()

0 comments on commit 78a1f80

Please sign in to comment.