diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 120c1b71be72..ea2d07903e71 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -816,6 +816,11 @@ TVM_DLL const Op& vectorlow(); */ TVM_DLL const Op& vectorcombine(); +/*! + * \brief Dot product of two int8x4 vectors and add an optional accumulator + */ +TVM_DLL const Op& dp4a(); + /*! * \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA */ diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index caefc6a6bc16..bdbd6e2cdac0 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1932,6 +1932,7 @@ def wrapped(*args, **kwargs): vectorhigh = _dtype_forward(_tir_op.vectorhigh) vectorcombine = _dtype_forward(_tir_op.vectorcombine) get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask) +dp4a = _dtype_forward(_tir_op.dp4a) broadcast = Broadcast @@ -2191,6 +2192,7 @@ def wrapped(*args, **kwargs): "vectorlow", "vectorhigh", "vectorcombine", + "dp4a", "assume", "undef", "tvm_call_packed", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 5360ab2b9697..bcfbe6575d52 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -95,6 +95,7 @@ from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic from .op import vscale, get_active_lane_mask, get_vscale_expr +from .op import dp4a from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 81d6604259a3..0bc299e403c5 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1813,6 +1813,31 @@ def vectorcombine(dtype, vec1, vec2): return call_intrin(dtype, "tir.vectorcombine", vec1, vec2) +def dp4a(vec1, vec2, acc=0): + """Dot product of two int8x4 vectors and add an optional accumulator + + Parameters + ---------- + vec1 : int8x4 + The input vector. + + vec2 : int8x4 + The input vector. + + acc : int32 + The accumulator. + + Returns + ------- + call : PrimExpr + The call expression. + """ + vec1 = convert(vec1) + vec2 = convert(vec2) + acc = convert(acc) + return call_intrin("int32", "tir.dp4a", vec1, vec2, acc) + + def ret(val): """Create a tir return expression diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0404fd28230e..0d4a213a23aa 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -355,6 +355,11 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine) .set_attr("TScriptDtypePrintLocation", Integer(ScriptDtypePrintLocation::kFirst)); +TIR_DEFINE_BUILTIN_FUNC(dp4a) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + TIR_DEFINE_BUILTIN_FUNC(atomic_add) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/tests/python/tir-base/test_tir_op_types.py b/tests/python/tir-base/test_tir_op_types.py index 7398ee781b9e..aefab62559c2 100644 --- a/tests/python/tir-base/test_tir_op_types.py +++ b/tests/python/tir-base/test_tir_op_types.py @@ -295,6 +295,14 @@ def test_tir_op_vectorhigh(): assert expr.op.name == "tir.vectorhigh" +def test_tir_op_dp4a(): + vec1 = tir.Var("vec1", dtype="int8x4") + vec2 = tir.Var("vec2", dtype="int8x4") + acc = tir.Var("acc", dtype="int32") + expr = tir.dp4a(vec1, vec2, acc) + assert expr.op.name == "tir.dp4a" + + def test_tir_op_vectorcombine(): buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1) vec = buffer.vload([0, 0], dtype="int8x16")