diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index ce1c5986e1cad..2bcb8563a2940 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -445,10 +445,7 @@ def apply(
             UNROLL = 256
             SUPPORT_WARP_SHUFFLE = True
             if isinstance(len_S, int):
-                if len_S > len_R:
-                    TS, TR = 4, 64
-                else:
-                    TS, TR = 16, 32
+                TS, TR = 16, 32
             else:
                 TS, TR = 1, 64
         elif target.kind.name == "metal":
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index adcb05839bc91..c22f3f01a8805 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -33,6 +33,13 @@
 from tvm.target import Target
 from tvm.topi.utils import get_const_tuple
 
+from tvm.script import ir as I, relax as R, tir as T
+
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
+
 
 @tvm.testing.requires_cuda_compute_version(9)
 def test_e4m3_conversions():
@@ -814,5 +821,119 @@ def func(A: T.Buffer((4,), dtype)) -> None:
     tvm.build(mod, target="cuda")
 
 
+num_experts = 8
+reduce_size = 1792
+spatial_size = 4096
+
+
+@tvm.testing.requires_cuda_compute_version(9)
+@pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes to be installed")
+def test_moe_gemv_shfl_down_illegal_instr():
+    global num_experts
+    global reduce_size
+    global spatial_size
+
+    @I.ir_module
+    class SingleBatchMoE_float8_e4m3:
+        @T.prim_func(private=True)
+        def moe_dequantize_gemv(
+            x_handle: T.handle,
+            w: T.Buffer((num_experts, spatial_size, reduce_size), "e4m3_float8"),
+            scale: T.Buffer((1,), "float16"),
+            indptr: T.Buffer((1, 2), "int32"),
+            o: T.Buffer((2, spatial_size), "float16"),
+        ):
+            T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+            num_seq = T.int64()
+            x = T.match_buffer(x_handle, (num_seq, reduce_size), "float16")
+            for expert_id in T.thread_binding(2, thread="blockIdx.y"):
+                with T.block("gemv_o"):
+                    e = T.axis.spatial(2, expert_id)
+                    T.reads(
+                        w[indptr[0, e], 0:spatial_size, 0:reduce_size],
+                        indptr[0, e],
+                        scale[0],
+                        x[e, 0:reduce_size],
+                    )
+                    T.writes(o[e, 0:spatial_size])
+                    y = T.alloc_buffer((spatial_size, reduce_size), "float16")
+                    for i1, i2 in T.grid(spatial_size, reduce_size):
+                        with T.block("dequantize"):
+                            i, j = T.axis.remap("SS", [i1, i2])
+                            T.reads(w[indptr[0, e], i, j], indptr[0, e], scale[0])
+                            T.writes(y[i, j])
+                            y[i, j] = T.Cast("float16", w[indptr[0, e], i, j]) * scale[0]
+                    for i1, i2 in T.grid(spatial_size, reduce_size):
+                        with T.block("gemv"):
+                            i, j = T.axis.remap("SR", [i1, i2])
+                            T.reads(x[e, j], y[i, j])
+                            T.writes(o[e, i])
+                            with T.init():
+                                o[e, i] = T.float16(0)
+                            o[e, i] = o[e, i] + x[e, j] * y[i, j]
+
+        @R.function
+        def main(
+            x: R.Tensor(("num_seq", reduce_size), dtype="float16"),
+            indptr: R.Tensor((1, 2), dtype="int32"),
+            weight: R.Tensor((num_experts, spatial_size, reduce_size), dtype="e4m3_float8"),
+            scale: R.Tensor((1,), dtype="float32"),
+        ) -> R.Tensor((2, spatial_size), dtype="float16"):
+            num_seq = T.int64()
+            R.func_attr({"num_input": 2})
+            cls = SingleBatchMoE_float8_e4m3
+            with R.dataflow():
+                astype: R.Tensor((1,), dtype="float16") = R.astype(scale, dtype="float16")
+                lv = R.call_tir(
+                    cls.moe_dequantize_gemv,
+                    (x, weight, astype, indptr),
+                    out_sinfo=R.Tensor((2, spatial_size), dtype="float16"),
+                )
+                gv: R.Tensor((2, spatial_size), dtype="float16") = lv
+                R.output(gv)
+            return gv
+
+    def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
+        seq = tvm.transform.Sequential(
+            [
+                tvm.relax.transform.LegalizeOps(),
+                tvm.dlight.ApplyDefaultSchedule(
+                    tvm.dlight.gpu.Matmul(),
+                    tvm.dlight.gpu.GEMV(),
+                    tvm.dlight.gpu.Reduction(),
+                    tvm.dlight.gpu.GeneralReduction(),
+                    tvm.dlight.gpu.Fallback(),
+                ),
+            ]
+        )
+        mod = seq(mod)
+        return mod
+
+    mod = SingleBatchMoE_float8_e4m3
+
+    target = tvm.target.Target("cuda")
+    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}), target:
+        mod = _pipeline(mod)
+        rt_mod = tvm.relax.build(mod, target=target)
+    dev = tvm.cuda(0)
+
+    x_data = np.zeros((1, reduce_size), dtype=np.float16)
+    x = tvm.nd.array(x_data, device=dev)
+
+    indptr_data = np.zeros((1, 2), dtype=np.int32)
+    indptr = tvm.nd.array(indptr_data, device=dev)
+
+    weight_data = np.zeros((num_experts, spatial_size, reduce_size), dtype="float8_e4m3fn")
+    weight = tvm.nd.array(weight_data, device=dev)
+
+    scale_data = np.zeros((1,), dtype=np.float32)
+    scale = tvm.nd.array(scale_data, device=dev)
+
+    vm = relax.VirtualMachine(rt_mod, dev)
+    # Ensure this runs without failure. With the previous dlight thread extents
+    # TS, TR = 4, 64, the GEMV schedule produced: CUDA: an illegal instruction was encountered.
+    vm["main"](x, indptr, weight, scale)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
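
Reviewer note: the sketch below is not part of the patch. It is a minimal example, under stated assumptions, of how the dlight GEMV rule can be applied to a standalone float16 GEMV so the thread extents picked by the patched heuristic (TS, TR = 16, 32 whenever len_S is a static int) can be inspected. The gemv function, its shapes, and the "main" name are illustrative assumptions, not taken from the patch.

import tvm
from tvm import dlight
from tvm.script import tir as T


@T.prim_func
def gemv(
    x: T.Buffer((1792,), "float16"),
    w: T.Buffer((4096, 1792), "float16"),
    o: T.Buffer((4096,), "float16"),
):
    # Plain spatial-by-reduction GEMV: o[i] = sum_j x[j] * w[i, j]
    T.func_attr({"tir.noalias": T.bool(True)})
    for i, j in T.grid(4096, 1792):
        with T.block("gemv"):
            vi, vj = T.axis.remap("SR", [i, j])
            with T.init():
                o[vi] = T.float16(0)
            o[vi] = o[vi] + x[vj] * w[vi, vj]


mod = tvm.IRModule({"main": gemv})
with tvm.target.Target("cuda"):
    # ApplyDefaultSchedule reads the current target; GEMV() is the rule patched above.
    mod = dlight.ApplyDefaultSchedule(dlight.gpu.GEMV())(mod)
mod.show()  # the threadIdx.y / threadIdx.x extents reflect the TS / TR choice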