[Dlight] Always use 16x32 spatial x reduction thread extents.
csullivan committed Jun 11, 2024
1 parent ab02979 commit be96146
Showing 2 changed files with 122 additions and 4 deletions.
5 changes: 1 addition & 4 deletions python/tvm/dlight/gpu/gemv.py
@@ -445,10 +445,7 @@ def apply(
            UNROLL = 256
            SUPPORT_WARP_SHUFFLE = True
            if isinstance(len_S, int):
-               if len_S > len_R:
-                   TS, TR = 4, 64
-               else:
-                   TS, TR = 16, 32
+               TS, TR = 16, 32
            else:
                TS, TR = 1, 64
        elif target.kind.name == "metal":
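For orientation, here is a minimal, hypothetical sketch (not the actual dlight GEMV rule; the buffer shapes, loop structure, and schedule below are illustrative only) of what spatial x reduction thread extents such as TS, TR = 16, 32 mean in a TensorIR schedule: the spatial loop is split by TS and bound to threadIdx.y, the reduction loop is split by TR and bound to threadIdx.x, so each block launches TS * TR = 512 threads and the TR lanes cooperate on a cross-thread reduction.

import tvm
from tvm.script import tir as T

@T.prim_func
def gemv(x: T.Buffer((4096, 1792), "float16"),
         v: T.Buffer((1792,), "float16"),
         out: T.Buffer((4096,), "float16")):
    for i, j in T.grid(4096, 1792):
        with T.block("gemv"):
            vi, vj = T.axis.remap("SR", [i, j])
            with T.init():
                out[vi] = T.float16(0)
            out[vi] = out[vi] + x[vi, vj] * v[vj]

TS, TR = 16, 32  # spatial x reduction thread extents, as fixed by this commit
sch = tvm.tir.Schedule(gemv)
block = sch.get_block("gemv")
i, j = sch.get_loops(block)
bx, ts = sch.split(i, factors=[None, TS])   # spatial axis: blocks x TS threads
ro, tr = sch.split(j, factors=[None, TR])   # reduction axis: serial outer x TR lanes
sch.bind(bx, "blockIdx.x")
sch.bind(ts, "threadIdx.y")
sch.bind(tr, "threadIdx.x")  # TR lanes perform a cross-thread reduction

With the previous TS, TR = 4, 64 choice (taken when len_S > len_R), the block shape becomes 4 x 64; the test added below exercises the FP8 MoE GEMV path where that configuration triggered an illegal instruction on CUDA (see the comment at the end of the test).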
121 changes: 121 additions & 0 deletions tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -33,6 +33,13 @@
from tvm.target import Target
from tvm.topi.utils import get_const_tuple

from tvm.script import ir as I, relax as R, tir as T

try:
    import ml_dtypes
except ImportError:
    ml_dtypes = None


@tvm.testing.requires_cuda_compute_version(9)
def test_e4m3_conversions():
@@ -814,5 +821,119 @@ def func(A: T.Buffer((4,), dtype)) -> None:
    tvm.build(mod, target="cuda")


num_experts = 8
reduce_size = 1792
spatial_size = 4096


@tvm.testing.requires_cuda_compute_version(9)
@pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes to be installed")
def test_moe_gemv_shfl_down_illegal_instr():
    global num_experts
    global reduce_size
    global spatial_size

    @I.ir_module
    class SingleBatchMoE_float8_e4m3:
        @T.prim_func(private=True)
        def moe_dequantize_gemv(
            x_handle: T.handle,
            w: T.Buffer((num_experts, spatial_size, reduce_size), "e4m3_float8"),
            scale: T.Buffer((1,), "float16"),
            indptr: T.Buffer((1, 2), "int32"),
            o: T.Buffer((2, spatial_size), "float16"),
        ):
            T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
            num_seq = T.int64()
            x = T.match_buffer(x_handle, (num_seq, reduce_size), "float16")
            for expert_id in T.thread_binding(2, thread="blockIdx.y"):
                with T.block("gemv_o"):
                    e = T.axis.spatial(2, expert_id)
                    T.reads(
                        w[indptr[0, e], 0:spatial_size, 0:reduce_size],
                        indptr[0, e],
                        scale[0],
                        x[e, 0:reduce_size],
                    )
                    T.writes(o[e, 0:spatial_size])
                    y = T.alloc_buffer((spatial_size, reduce_size), "float16")
                    for i1, i2 in T.grid(spatial_size, reduce_size):
                        with T.block("dequantize"):
                            i, j = T.axis.remap("SS", [i1, i2])
                            T.reads(w[indptr[0, e], i, j], indptr[0, e], scale[0])
                            T.writes(y[i, j])
                            y[i, j] = T.Cast("float16", w[indptr[0, e], i, j]) * scale[0]
                    for i1, i2 in T.grid(spatial_size, reduce_size):
                        with T.block("gemv"):
                            i, j = T.axis.remap("SR", [i1, i2])
                            T.reads(x[e, j], y[i, j])
                            T.writes(o[e, i])
                            with T.init():
                                o[e, i] = T.float16(0)
                            o[e, i] = o[e, i] + x[e, j] * y[i, j]

        @R.function
        def main(
            x: R.Tensor(("num_seq", reduce_size), dtype="float16"),
            indptr: R.Tensor((1, 2), dtype="int32"),
            weight: R.Tensor((num_experts, spatial_size, reduce_size), dtype="e4m3_float8"),
            scale: R.Tensor((1,), dtype="float32"),
        ) -> R.Tensor((2, spatial_size), dtype="float16"):
            num_seq = T.int64()
            R.func_attr({"num_input": 2})
            cls = SingleBatchMoE_float8_e4m3
            with R.dataflow():
                astype: R.Tensor((1,), dtype="float16") = R.astype(scale, dtype="float16")
                lv = R.call_tir(
                    cls.moe_dequantize_gemv,
                    (x, weight, astype, indptr),
                    out_sinfo=R.Tensor((2, spatial_size), dtype="float16"),
                )
                gv: R.Tensor((2, spatial_size), dtype="float16") = lv
                R.output(gv)
            return gv

    def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
        seq = tvm.transform.Sequential(
            [
                tvm.relax.transform.LegalizeOps(),
                tvm.dlight.ApplyDefaultSchedule(
                    tvm.dlight.gpu.Matmul(),
                    tvm.dlight.gpu.GEMV(),
                    tvm.dlight.gpu.Reduction(),
                    tvm.dlight.gpu.GeneralReduction(),
                    tvm.dlight.gpu.Fallback(),
                ),
            ]
        )
        mod = seq(mod)
        return mod

    mod = SingleBatchMoE_float8_e4m3

    target = tvm.target.Target("cuda")
    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}) and target:
        mod = _pipeline(mod)
        rt_mod = tvm.relax.build(mod, target=target)
    dev = tvm.cuda(0)

    x_data = np.zeros((1, reduce_size), dtype=np.float16)
    x = tvm.nd.array(x_data, device=dev)

    indptr_data = np.zeros((1, 2), dtype=np.int32)
    indptr = tvm.nd.array(indptr_data, device=dev)

    weight_data = np.zeros((num_experts, spatial_size, reduce_size), dtype="float8_e4m3fn")
    weight = tvm.nd.array(weight_data, device=dev)

    scale_data = np.zeros((1,), dtype=np.float32)
    scale = tvm.nd.array(scale_data, device=dev)

    vm = relax.VirtualMachine(rt_mod, dev)
    # Ensure this runs without failure. Utilizing dlight thread extents TS, TR = 4, 64
    # in GEMV scheduling will yield: CUDA: an illegal instruction was encountered.
    vm["main"](x, indptr, weight, scale)


if __name__ == "__main__":
tvm.testing.main()
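Side note on the ml_dtypes guard at the top of the test: the weight buffer above is built directly as an FP8 NumPy array, which only works once ml_dtypes has made float8_e4m3fn available as a NumPy dtype; hence the guarded import and the skip marker when the package is absent. A small sketch reusing the test's shapes (not part of the commit):

import numpy as np
import ml_dtypes  # makes float8_e4m3fn usable as a NumPy dtype

# Placeholder zero weights with the same shape as the test's weight buffer.
weight_data = np.zeros((8, 4096, 1792), dtype=ml_dtypes.float8_e4m3fn)
print(weight_data.dtype)  # float8_e4m3fn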
