From e4ac0d035d00f3e61fa29f0f0690bd6f7c595953 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 18:55:47 -0400 Subject: [PATCH 01/20] Commit scatter2scatter --- scattermoe/kernels/ops.py | 361 +++++++++++++++++++++++++-------- scattermoe/parallel_experts.py | 2 + 2 files changed, 283 insertions(+), 80 deletions(-) diff --git a/scattermoe/kernels/ops.py b/scattermoe/kernels/ops.py index 2386b6a..1f66e41 100644 --- a/scattermoe/kernels/ops.py +++ b/scattermoe/kernels/ops.py @@ -6,6 +6,7 @@ BLOCK_M = 128 ALLOW_TF32 = False + @torch.library.custom_op("scattermoe::bincount", mutates_args={}) def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: return x.bincount(minlength=minlength) @@ -21,53 +22,57 @@ def flatten_and_sort(expert_idxs:torch.Tensor): return sorted_expert_idxs, sorted_scattered_idxs @torch.compile -def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int=BLOCK_M) : - expert_counts = compileable_bincount(sorted_experts_idxs, minlength=k) +def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): + # there is an overhead of launching a custom op so we only use the custom op when compiling + if torch.compiler.is_compiling(): + expert_counts = compileable_bincount(sorted_experts_idxs, k) + else: + expert_counts = sorted_experts_idxs.bincount(minlength=k) + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 padded_expert_block_end = padded_block_counts.cumsum(-1) expert_boundaries_end = expert_counts.cumsum(-1) expert_boundaries_start = expert_boundaries_end - expert_counts padded_expert_block_start = padded_expert_block_end - padded_block_counts - block_idxs = torch.arange(padded_expert_block_end[-1], - dtype=sorted_experts_idxs.dtype, - device=sorted_experts_idxs.device) - block_mask = ( - (block_idxs[:, None] < padded_expert_block_start) | - (block_idxs[:, None] >= padded_expert_block_end) - ) - expanded_block_idxs = ( - N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) + - expert_boundaries_start - ) - expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) - return expanded_block_idxs, expert_boundaries_end + block_idxs = torch.arange( + padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device + ).unsqueeze(1) + block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end) + expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) -def _scatter2scatter_configs(): - return [ - triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4), - ] + return expanded_block_idxs, expert_boundaries_end -@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], ) -@triton.heuristics({ - "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, - "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, -}) + +@triton.autotune( + configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], + key=["N", "K"], +) +# @triton.heuristics({ +# "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, +# "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, +# }) @triton.jit -def _scatter2scatter( +def scatter2scatter_triton_kernel( X_ptr, stride_xm, stride_xk, W_ptr, stride_we, stride_wk, stride_wn, Y_ptr, stride_ym, stride_yn, - grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr, 
- FAN_OUT: tl.constexpr, - M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + grouped_idx_ptr, + expert_idxs_ptr, + block_start_idx_ptr, + FAN_OUT, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr, - OUT_M, allow_tf32: tl.constexpr, - x_grouped: tl.constexpr, y_grouped: tl.constexpr, - NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr + x_grouped, y_grouped, ): pid = tl.program_id(axis=0) @@ -76,12 +81,13 @@ def _scatter2scatter( N_block_id = pid % N_BLOCK_COUNT M_range = tl.arange(0, BLOCK_M) block_start_idx = tl.load(block_start_idx_ptr + M_block_id) - # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) E_idx = tl.min(E_idxs) E_mask = E_idxs == E_idx M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) + if x_grouped: M_in_idx = M_block else: @@ -94,20 +100,23 @@ def _scatter2scatter( K_block = tl.arange(0, BLOCK_K) - N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) N_mask = N_block < N - # N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) - # N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) iters = tl.cdiv(K, BLOCK_K) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + for K_block_id in range(0, iters): - if NO_K_MASK: + if no_k_mask: x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) - if NO_N_MASK or K_block_id < (iters - 1): + + if no_n_mask or K_block_id < (iters - 1): w = tl.load(W_blk_ptrs) else: w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) @@ -115,6 +124,7 @@ def _scatter2scatter( K_mask = (K_block_id * BLOCK_K + K_block) < K x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + X_blk_ptrs += BLOCK_K * stride_xk W_blk_ptrs += BLOCK_K * stride_wk acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) @@ -122,65 +132,256 @@ def _scatter2scatter( Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) -def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k, - padded_block_idxs, x_grouped=False, y_grouped=False, - out=None): - assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) - assert sorted_scattered_idxs.size(0) == X.size(0) * k - # Pre-kernel setup - x_dim = X.size(-1) - y_dim = W.size(-1) - L_scattered = sorted_expert_idxs.size(0) - if out is None: - O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) + +@triton.autotune( + configs=[ + # different block M and reducing stages + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4), + ], + key=["N", "K"], +) +@triton.jit +def groupXtY_triton_kernel( + DY_ptr, stride_dym, stride_dyk, + X_ptr, 
stride_xm, stride_xn, + DW_ptr, stride_dwe, stride_dwk, stride_dwn, + expert_offsets_ptr, + K: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 else: - assert out.size(0) == L_scattered and out.size(1) == y_dim - O = out + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) - # with torch.cuda.device(X.device): - scatter2scatter_compileable(O, W, X, k, padded_block_idxs, sorted_expert_idxs, sorted_scattered_idxs, - x_grouped, y_grouped) - return O + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) -@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"O"}) -def scatter2scatter_compileable( - O: torch.Tensor, - W: torch.Tensor, - X: torch.Tensor, - k: int, - padded_block_idxs: torch.Tensor, - sorted_expert_idxs: torch.Tensor, - sorted_scattered_idxs: torch.Tensor, - x_grouped: bool, y_grouped: bool) -> None: - def grid(META): - grid_num = ( - padded_block_idxs.size(0) * - triton.cdiv(META['N'], META['BLOCK_N']), - ) - return grid_num + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) - _scatter2scatter[grid]( + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for i in range(0, iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + + if no_k_mask: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + + if no_n_mask: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) + + DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + + +@triton.autotune(configs=[triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4)], key=["K"]) +@triton.jit +def group_triton_kernel( + src_ptr, stride_sn, stride_sk, + has_coeff: tl.constexpr, + coeff_ptr, + FAN_OUT: tl.constexpr, + tgt_ptr, stride_tn, stride_ti, + grouped_idx_ptr, + N, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_block_id = pid + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) 
+ N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) + + K_blk = tl.arange(0, BLOCK_K) + src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + no_k_mask = K % BLOCK_K == 0 + + for i in range(0, iters): + if no_k_mask or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=mask) + + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti + + + + +def _scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * FAN_OUT + assert out.size(0) == sorted_expert_idxs.size(0) + assert out.size(1) == W.size(-1) + + grid = lambda meta: (padded_block_idxs.size(0) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + scatter2scatter_triton_kernel[grid]( # X_ptr, stride_xm, stride_xk, - X, X.stride(0), X.stride(1), + X, + X.stride(0), + X.stride(1), # W_ptr, stride_we, stride_wk, stride_wn, - W, W.stride(0), W.stride(1), W.stride(2), + W, + W.stride(0), + W.stride(1), + W.stride(2), # Y_ptr, stride_ym, stride_yn, - O, O.stride(0), O.stride(1), + out, + out.stride(0), + out.stride(1), grouped_idx_ptr=sorted_scattered_idxs, expert_idxs_ptr=sorted_expert_idxs, block_start_idx_ptr=padded_block_idxs, - FAN_OUT=k, + FAN_OUT=FAN_OUT, M=X.size(0), K=X.size(1), - N=O.size(1), E=W.size(0), + N=out.size(1), + E=W.size(0), BLOCK_M=BLOCK_M, ACC_TYPE=tl.float32, - OUT_M=O.size(0), - allow_tf32=ALLOW_TF32, - x_grouped=x_grouped, y_grouped=y_grouped, + allow_tf32=torch.backends.cudnn.allow_tf32, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"scattermoe::scatter2scatter", mutates_args={"out"}) +def _scatter2scatter_compileable( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, ) +def scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + if torch.compiler.is_compiling(): + _scatter2scatter_compileable( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + 
y_grouped=y_grouped, + ) + else: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + def _config_XtY(): return [ diff --git a/scattermoe/parallel_experts.py b/scattermoe/parallel_experts.py index 6621899..bc929cb 100644 --- a/scattermoe/parallel_experts.py +++ b/scattermoe/parallel_experts.py @@ -11,8 +11,10 @@ def forward( gates=None, grouped_in=False, grouped_out=False, ): with torch.device(x.device): + O = torch.empty((sorted_expert_idxs.size(0), W.size(-1)), device=X.device, dtype=X.dtype) output = kernels.ops.scatter2scatter( X=x, W=expert_weights, + out=O, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, padded_block_idxs=padded_block_idxs, From c9ca747f71e3199d57c0eff3a45ba980decbc5c4 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 18:59:30 -0400 Subject: [PATCH 02/20] Modified. --- scattermoe/kernels/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scattermoe/kernels/ops.py b/scattermoe/kernels/ops.py index 1f66e41..b12727e 100644 --- a/scattermoe/kernels/ops.py +++ b/scattermoe/kernels/ops.py @@ -322,7 +322,7 @@ def _scatter2scatter( # custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 -@torch_custom_op(f"scattermoe::scatter2scatter", mutates_args={"out"}) +@torch.library.custom_op(f"scattermoe::scatter2scatter", mutates_args={"out"}) def _scatter2scatter_compileable( X: torch.Tensor, W: torch.Tensor, From 495c0eea3768230472b737fb4bd8077331524c50 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 20:13:10 -0400 Subject: [PATCH 03/20] Imported khd. --- scattermoe/kernels/__init__.py | 27 +- scattermoe/kernels/compileable_ops.py | 234 ++++++++++ scattermoe/kernels/ops.py | 590 -------------------------- scattermoe/kernels/triton.py | 250 +++++++++++ scattermoe/mlp.py | 4 +- scattermoe/parallel_experts.py | 42 +- 6 files changed, 540 insertions(+), 607 deletions(-) create mode 100644 scattermoe/kernels/compileable_ops.py delete mode 100644 scattermoe/kernels/ops.py create mode 100644 scattermoe/kernels/triton.py diff --git a/scattermoe/kernels/__init__.py b/scattermoe/kernels/__init__.py index 14c0c07..bb2ba7f 100644 --- a/scattermoe/kernels/__init__.py +++ b/scattermoe/kernels/__init__.py @@ -1,2 +1,25 @@ -from . import ops -from . import single \ No newline at end of file +from . import compileable_ops as ops +from . 
import single + +def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): + # there is an overhead of launching a custom op so we only use the custom op when compiling + if torch.compiler.is_compiling(): + expert_counts = compileable_bincount(sorted_experts_idxs, k) + else: + expert_counts = sorted_experts_idxs.bincount(minlength=k) + + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 + padded_expert_block_end = padded_block_counts.cumsum(-1) + expert_boundaries_end = expert_counts.cumsum(-1) + expert_boundaries_start = expert_boundaries_end - expert_counts + padded_expert_block_start = padded_expert_block_end - padded_block_counts + + block_idxs = torch.arange( + padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device + ).unsqueeze(1) + + block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end) + expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) + + return expanded_block_idxs, expert_boundaries_end diff --git a/scattermoe/kernels/compileable_ops.py b/scattermoe/kernels/compileable_ops.py new file mode 100644 index 0000000..82b6685 --- /dev/null +++ b/scattermoe/kernels/compileable_ops.py @@ -0,0 +1,234 @@ +import torch +import triton +import triton.language as tl + +from torch.library import custom_op as torch_custom_op +from .triton import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel + + +LIBRARY_NAME = "scattermoe" +BLOCK_M = 128 +torch._dynamo.config.capture_scalar_outputs = True + + +# bincount is not compilable +@torch_custom_op(f"{LIBRARY_NAME}::bincount", mutates_args={}) +def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength) + + +@compileable_bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, dtype=torch.long, device=x.device) + + +def _scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * FAN_OUT + assert out.size(0) == sorted_expert_idxs.size(0) + assert out.size(1) == W.size(-1) + + grid = lambda meta: (padded_block_idxs.size(0) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + scatter2scatter_triton_kernel[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + W.stride(0), + W.stride(1), + W.stride(2), + # Y_ptr, stride_ym, stride_yn, + out, + out.stride(0), + out.stride(1), + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + block_start_idx_ptr=padded_block_idxs, + FAN_OUT=FAN_OUT, + M=X.size(0), + K=X.size(1), + N=out.size(1), + E=W.size(0), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::scatter2scatter", mutates_args={"out"}) +def _scatter2scatter_compileable( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: 
torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + if torch.compiler.is_compiling(): + _scatter2scatter_compileable( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + else: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def _group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + grid = lambda meta: (E * triton.cdiv(meta["K"], meta["BLOCK_K"]), triton.cdiv(meta["N"], meta["BLOCK_N"])) + + groupXtY_triton_kernel[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, + DY.stride(0), + DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, + X.stride(0), + X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, + DW.stride(0), + DW.stride(1), + DW.stride(2), + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, + N=DY.size(-1), + K=X.size(-1), + # ACC_TYPE: tl.constexpr, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group_bwd_W", mutates_args={"DW"}) +def _group_bwd_W_compileable( + DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int +) -> None: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + if torch.compiler.is_compiling(): + _group_bwd_W_compileable(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + else: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def _group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + + grid = lambda meta: (triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + group_triton_kernel[grid]( + # A_ptr, stride_an, stride_ai, + A, + A.stride(0), + A.stride(1), + coeff is not None, + coeff, + fan_out, + # Y_ptr, stride_yn, stride_yk, + out, + out.stride(0), + out.stride(1), + # grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, + K, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group", mutates_args={"out"}) +def _group_compileable( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None 
= None, + fan_out: int = 1, +) -> None: + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) + + +def group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + if torch.compiler.is_compiling(): + _group_compileable(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) + else: + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) \ No newline at end of file diff --git a/scattermoe/kernels/ops.py b/scattermoe/kernels/ops.py deleted file mode 100644 index b12727e..0000000 --- a/scattermoe/kernels/ops.py +++ /dev/null @@ -1,590 +0,0 @@ -import torch -import triton -import triton.language as tl -from torch.nn import functional as F - -BLOCK_M = 128 -ALLOW_TF32 = False - - -@torch.library.custom_op("scattermoe::bincount", mutates_args={}) -def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: - return x.bincount(minlength=minlength) - -@compileable_bincount.register_fake -def _(x: torch.Tensor, minlength: int) -> torch.Tensor: - return torch.empty(minlength, dtype=torch.long, device=x.device) - -@torch.compile -def flatten_and_sort(expert_idxs:torch.Tensor): - flattened_expert_idxs = expert_idxs.flatten() - sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) - return sorted_expert_idxs, sorted_scattered_idxs - -@torch.compile -def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): - # there is an overhead of launching a custom op so we only use the custom op when compiling - if torch.compiler.is_compiling(): - expert_counts = compileable_bincount(sorted_experts_idxs, k) - else: - expert_counts = sorted_experts_idxs.bincount(minlength=k) - - padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 - padded_expert_block_end = padded_block_counts.cumsum(-1) - expert_boundaries_end = expert_counts.cumsum(-1) - expert_boundaries_start = expert_boundaries_end - expert_counts - padded_expert_block_start = padded_expert_block_end - padded_block_counts - - block_idxs = torch.arange( - padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device - ).unsqueeze(1) - - block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end) - expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start - expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) - - return expanded_block_idxs, expert_boundaries_end - - -@triton.autotune( - configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], - key=["N", "K"], -) -# @triton.heuristics({ -# "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, -# "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, -# }) -@triton.jit -def scatter2scatter_triton_kernel( - X_ptr, stride_xm, stride_xk, - W_ptr, stride_we, stride_wk, stride_wn, - Y_ptr, stride_ym, stride_yn, - grouped_idx_ptr, - expert_idxs_ptr, - block_start_idx_ptr, - FAN_OUT, - M, - K: tl.constexpr, - N: tl.constexpr, - E: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - allow_tf32: tl.constexpr, - x_grouped, y_grouped, -): - pid = tl.program_id(axis=0) - - N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) - M_block_id = pid // N_BLOCK_COUNT - N_block_id = pid % N_BLOCK_COUNT - M_range = tl.arange(0, 
BLOCK_M) - block_start_idx = tl.load(block_start_idx_ptr + M_block_id) - - M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) - E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) - E_idx = tl.min(E_idxs) - E_mask = E_idxs == E_idx - M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) - - if x_grouped: - M_in_idx = M_block - else: - M_in_idx = M_idx // FAN_OUT - - if y_grouped: - M_out_idx = M_block - else: - M_out_idx = M_idx - - K_block = tl.arange(0, BLOCK_K) - - N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - N_mask = N_block < N - - X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk - W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - iters = tl.cdiv(K, BLOCK_K) - - no_k_mask = K % BLOCK_K == 0 - no_n_mask = N % BLOCK_N == 0 - - for K_block_id in range(0, iters): - if no_k_mask: - x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) - - if no_n_mask or K_block_id < (iters - 1): - w = tl.load(W_blk_ptrs) - else: - w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) - else: - K_mask = (K_block_id * BLOCK_K + K_block) < K - x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) - w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) - - X_blk_ptrs += BLOCK_K * stride_xk - W_blk_ptrs += BLOCK_K * stride_wk - acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) - - Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) - tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) - - -@triton.autotune( - configs=[ - # different block M and reducing stages - triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4), - triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4), - ], - key=["N", "K"], -) -@triton.jit -def groupXtY_triton_kernel( - DY_ptr, stride_dym, stride_dyk, - X_ptr, stride_xm, stride_xn, - DW_ptr, stride_dwe, stride_dwk, stride_dwn, - expert_offsets_ptr, - K: tl.constexpr, - N: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - allow_tf32: tl.constexpr, -): - pid0 = tl.program_id(axis=0) - pid1 = tl.program_id(axis=1) - num0 = tl.num_programs(0) - num1 = tl.num_programs(1) - pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) - - K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) - E_idx = pid0 // K_BLOCK_COUNT - K_block_id = pid0 % K_BLOCK_COUNT - N_block_id = pid1 - - if E_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) - - end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) - - if end_idx > start_idx: - M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) - - K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) - K_mask = K_block < K - K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) - - N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - N_mask = N_block < N - N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) - - M_idxs = M_block - xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm - dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk - - acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) - iters = tl.cdiv(end_idx - 
start_idx, BLOCK_M) - - no_k_mask = K % BLOCK_K == 0 - no_n_mask = N % BLOCK_N == 0 - - for i in range(0, iters): - M_mask = (i * BLOCK_M + M_block) < end_idx - - if no_k_mask: - xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) - else: - xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) - - if no_n_mask: - dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) - else: - dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) - - xt_blk_ptrs += BLOCK_M * stride_xm - dy_blk_ptrs += BLOCK_M * stride_dym - acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) - - DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn - acc = acc.to(DW_blk_ptrs.dtype.element_ty) - tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) - - -@triton.autotune(configs=[triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4)], key=["K"]) -@triton.jit -def group_triton_kernel( - src_ptr, stride_sn, stride_sk, - has_coeff: tl.constexpr, - coeff_ptr, - FAN_OUT: tl.constexpr, - tgt_ptr, stride_tn, stride_ti, - grouped_idx_ptr, - N, - K: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(axis=0) - - N_block_id = pid - N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - N_mask = N_blk < N - N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) - N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) - - K_blk = tl.arange(0, BLOCK_K) - src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk - tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti - - if has_coeff: - c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] - - iters = tl.cdiv(K, BLOCK_K) - no_k_mask = K % BLOCK_K == 0 - - for i in range(0, iters): - if no_k_mask or i < iters - 1: - block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) - - if has_coeff: - block *= c - - tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) - else: - K_mask = (i * BLOCK_K + K_blk) < K - mask = N_mask[:, None] & K_mask[None, :] - block = tl.load(src_blk_ptrs, mask=mask) - - if has_coeff: - block *= c - - tl.store(tgt_blk_ptrs, block, mask=mask) - - src_blk_ptrs += BLOCK_K * stride_sk - tgt_blk_ptrs += BLOCK_K * stride_ti - - - - -def _scatter2scatter( - X: torch.Tensor, - W: torch.Tensor, - sorted_expert_idxs: torch.Tensor, - sorted_scattered_idxs: torch.Tensor, - padded_block_idxs: torch.Tensor, - out: torch.Tensor, - FAN_OUT: int, - x_grouped: bool = False, - y_grouped: bool = False, -) -> None: - assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) - assert sorted_scattered_idxs.size(0) == X.size(0) * FAN_OUT - assert out.size(0) == sorted_expert_idxs.size(0) - assert out.size(1) == W.size(-1) - - grid = lambda meta: (padded_block_idxs.size(0) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) - - scatter2scatter_triton_kernel[grid]( - # X_ptr, stride_xm, stride_xk, - X, - X.stride(0), - X.stride(1), - # W_ptr, stride_we, stride_wk, stride_wn, - W, - W.stride(0), - W.stride(1), - W.stride(2), - # Y_ptr, stride_ym, stride_yn, - out, - out.stride(0), - out.stride(1), - grouped_idx_ptr=sorted_scattered_idxs, - expert_idxs_ptr=sorted_expert_idxs, - block_start_idx_ptr=padded_block_idxs, - FAN_OUT=FAN_OUT, - M=X.size(0), - K=X.size(1), - N=out.size(1), - E=W.size(0), - BLOCK_M=BLOCK_M, - ACC_TYPE=tl.float32, - allow_tf32=torch.backends.cudnn.allow_tf32, - x_grouped=x_grouped, - y_grouped=y_grouped, - ) - - -# custom op is needed because of 
https://github.com/pytorch/pytorch/issues/136394 -@torch.library.custom_op(f"scattermoe::scatter2scatter", mutates_args={"out"}) -def _scatter2scatter_compileable( - X: torch.Tensor, - W: torch.Tensor, - sorted_expert_idxs: torch.Tensor, - sorted_scattered_idxs: torch.Tensor, - padded_block_idxs: torch.Tensor, - out: torch.Tensor, - FAN_OUT: int, - x_grouped: bool = False, - y_grouped: bool = False, -) -> None: - _scatter2scatter( - X=X, - W=W, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - padded_block_idxs=padded_block_idxs, - out=out, - FAN_OUT=FAN_OUT, - x_grouped=x_grouped, - y_grouped=y_grouped, - ) - -def scatter2scatter( - X: torch.Tensor, - W: torch.Tensor, - sorted_expert_idxs: torch.Tensor, - sorted_scattered_idxs: torch.Tensor, - padded_block_idxs: torch.Tensor, - out: torch.Tensor, - FAN_OUT: int, - x_grouped: bool = False, - y_grouped: bool = False, -) -> None: - if torch.compiler.is_compiling(): - _scatter2scatter_compileable( - X=X, - W=W, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - padded_block_idxs=padded_block_idxs, - out=out, - FAN_OUT=FAN_OUT, - x_grouped=x_grouped, - y_grouped=y_grouped, - ) - else: - _scatter2scatter( - X=X, - W=W, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - padded_block_idxs=padded_block_idxs, - out=out, - FAN_OUT=FAN_OUT, - x_grouped=x_grouped, - y_grouped=y_grouped, - ) - - -def _config_XtY(): - return [ - triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4), - ] - -def group_bwd_W(DY, X, expert_offsets, E): - DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype) - DW = DWt.permute(0, 2, 1) - groupXtY_compileable(E, DW, DY, X, expert_offsets) - return DW - - -@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"}) -def groupXtY_compileable( - E: int, - DW: torch.Tensor, - DY: torch.Tensor, - X: torch.Tensor, - expert_offsets: torch.Tensor) -> None: - def grid(META): - grid = ( - E * triton.cdiv(META['K'], META['BLOCK_K']), - triton.cdiv(META['N'], META['BLOCK_N']), - ) - return grid - - _groupXtY[grid]( - # DY_ptr, stride_dym, stride_dyk, - DY, DY.stride(0), DY.stride(1), - # X_ptr, stride_xm, stride_xn, - X, X.stride(0), X.stride(1), - # DW_ptr, stride_dwe, stride_dwk, stride_dwn, - DW, DW.stride(0), DW.stride(1), DW.stride(2), - # expert_offsets_ptr, - expert_offsets, - # K: tl.constexpr, N: tl.constexpr, - M=DY.size(0), N=DY.size(-1), K=X.size(-1), - # ACC_TYPE: tl.constexpr, - ACC_TYPE=tl.float32, - allow_tf32=True - ) - - -@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], ) -@triton.heuristics({ - "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, - "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, -}) -@triton.jit -def _groupXtY( - DY_ptr, stride_dym, stride_dyk, - X_ptr, stride_xm, stride_xn, - DW_ptr, stride_dwe, stride_dwk, stride_dwn, - expert_offsets_ptr, - M, K: tl.constexpr, N: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - allow_tf32: tl.constexpr, - NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr -): - pid0 = tl.program_id(axis=0) - pid1 = tl.program_id(axis=1) - num0 = tl.num_programs(0) - num1 = tl.num_programs(1) - # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128) - pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) - - K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) - E_idx = pid0 // K_BLOCK_COUNT - K_block_id = pid0 % 
K_BLOCK_COUNT - N_block_id = pid1 - - if E_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) - end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) - - if end_idx > start_idx: - M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) - - K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) - K_mask = K_block < K - K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) - - N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - N_mask = N_block < N - N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) - - M_idxs = M_block - xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm - dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk - - acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) - iters = tl.cdiv(end_idx - start_idx, BLOCK_M) - for i in range(0, iters): - M_mask = (i * BLOCK_M + M_block) < end_idx - if NO_K_MASK: - xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) - else: - xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) - if NO_N_MASK: - dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) - else: - dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) - # acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) - xt_blk_ptrs += BLOCK_M * stride_xm - dy_blk_ptrs += BLOCK_M * stride_dym - acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) - - - - DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn - acc = acc.to(DW_blk_ptrs.dtype.element_ty) - tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) - - -def _config_grouping(): - return [ - triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4), - # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), - # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4), - ] - -def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None): - N = sorted_expert_idxs.size(0) - K = A.size(1) - assert A.size(0) * fan_out == N - if out is not None: - Y = out - else: - Y = torch.empty((N, K), dtype=A.dtype, device=A.device) - group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs) - return Y - - -@torch.library.custom_op("scattermoe::group", mutates_args={"Y"}) -def group_compileable( - A: torch.Tensor, - K: int, - N: int, - Y: torch.Tensor, - coeff: torch.Tensor, has_coeff: bool, - fan_out: int, - sorted_expert_idxs: torch.Tensor) -> None: - def grid(META): - grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),) - return grid_num - _group[grid]( - # A_ptr, stride_an, stride_ai, - A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out, - # Y_ptr, stride_yn, stride_yk, - Y, Y.stride(0), Y.stride(1), - # grouped_idx_ptr, - sorted_expert_idxs, - # N: tl.constexpr, K: tl.constexpr, - N, K - ) - - -@triton.autotune(configs=_config_grouping(), key=['K']) -@triton.heuristics({ - "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0 -}) -@triton.jit -def _group( - src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr, - tgt_ptr, stride_tn, stride_ti, - grouped_idx_ptr, - N, K: tl.constexpr, - BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NO_K_MASK: tl.constexpr -): - pid = tl.program_id(axis=0) - - N_block_id = pid - N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - N_mask = N_blk < N - N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, 
BLOCK_N), BLOCK_N) - N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) - - K_blk = tl.arange(0, BLOCK_K) - src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk - tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti - - if has_coeff: - c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] - - iters = tl.cdiv(K, BLOCK_K) - for i in range(0, iters): - if NO_K_MASK or i < iters - 1: - block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) - if has_coeff: - block *= c - tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) - - else: - K_mask = (i * BLOCK_K + K_blk) < K - mask = N_mask[:, None] & K_mask[None, :] - block = tl.load(src_blk_ptrs, mask=mask) - if has_coeff: - block *= c - tl.store(tgt_blk_ptrs, block, mask=mask) - src_blk_ptrs += BLOCK_K * stride_sk - tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/scattermoe/kernels/triton.py b/scattermoe/kernels/triton.py new file mode 100644 index 0000000..b59f1fd --- /dev/null +++ b/scattermoe/kernels/triton.py @@ -0,0 +1,250 @@ +import triton +import triton.language as tl + + +BLOCK_M = 128 + + +@triton.autotune( + configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], + key=["N", "K"], +) +@triton.jit +def scatter2scatter_triton_kernel( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + Y_ptr, + stride_ym, + stride_yn, + grouped_idx_ptr, + expert_idxs_ptr, + block_start_idx_ptr, + FAN_OUT, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + x_grouped, + y_grouped, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load(block_start_idx_ptr + M_block_id) + + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) + + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(K, BLOCK_K) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for K_block_id in range(0, iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + + if no_n_mask or K_block_id < (iters - 1): + w = tl.load(W_blk_ptrs) + else: + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + 
+@triton.autotune( + configs=[ + # different block M and reducing stages + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4), + # keep 4 stages and keep two 64 block sizes + # - NOTE: these can get good performances for low M, but for large M the variation + # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + ], + key=["N", "K"], +) +@triton.jit +def groupXtY_triton_kernel( + DY_ptr, + stride_dym, + stride_dyk, + X_ptr, + stride_xm, + stride_xn, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + expert_offsets_ptr, + K: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for i in range(0, iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + + if no_k_mask: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + + if no_n_mask: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) + + DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + + +@triton.autotune(configs=[triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4)], key=["K"]) +@triton.jit +def group_triton_kernel( + src_ptr, + stride_sn, + stride_sk, + has_coeff: tl.constexpr, + coeff_ptr, + FAN_OUT, + tgt_ptr, + stride_tn, + stride_ti, + grouped_idx_ptr, + N, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + + 
N_block_id = pid + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) + N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) + + K_blk = tl.arange(0, BLOCK_K) + src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + no_k_mask = K % BLOCK_K == 0 + + for i in range(0, iters): + if no_k_mask or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=mask) + + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/scattermoe/mlp.py b/scattermoe/mlp.py index ef6d0a6..b6fac94 100644 --- a/scattermoe/mlp.py +++ b/scattermoe/mlp.py @@ -31,8 +31,8 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te x_shape = x.size() x = x.view(-1, x_shape[-1]) with torch.no_grad(): - sorted_expert_idxs, sorted_scattered_idxs = kernels.ops.flatten_and_sort(expert_idxs) - padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts) + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten()) + padded_block_idxs, expert_offsets = kernels.padded_block_indices(sorted_expert_idxs, self.num_experts) h, gates = self.experts( x, self.top_k, diff --git a/scattermoe/parallel_experts.py b/scattermoe/parallel_experts.py index bc929cb..653420d 100644 --- a/scattermoe/parallel_experts.py +++ b/scattermoe/parallel_experts.py @@ -11,8 +11,8 @@ def forward( gates=None, grouped_in=False, grouped_out=False, ): with torch.device(x.device): - O = torch.empty((sorted_expert_idxs.size(0), W.size(-1)), device=X.device, dtype=X.dtype) - output = kernels.ops.scatter2scatter( + output = torch.empty((sorted_expert_idxs.size(0), W.size(-1)), device=X.device, dtype=X.dtype) + kernels.ops.scatter2scatter( X=x, W=expert_weights, out=O, sorted_expert_idxs=sorted_expert_idxs, @@ -58,7 +58,6 @@ def backward(ctx, grad_out): d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1) gates_flat = gates.flatten() gate_fan = gates.size(1) - # print("expanded and grouping") grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later else: d_gates = None @@ -69,29 +68,46 @@ def backward(ctx, grad_out): if grouped_out: grouped_grad_out = grad_out else: - grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs, - fan_out=gate_fan, coeff=gates_flat, - out=grouped_grad_out) + kernels.ops.group( + A=grad_out, + sorted_expert_idxs=sorted_expert_idxs, + out=grouped_grad_out, + coeff=gates_flat, + fan_out=gate_fan, + ) if grouped_in: grouped_x = x d_expanded_input = None else: grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k) d_expanded_input = grouped_x - d_weights = kernels.ops.group_bwd_W( - DY=grouped_grad_out, X=grouped_x, + + d_weights = torch.zeros( + expert_weights.size(0), + grouped_grad_out.size(-1), + grouped_x.size(-1), + device=grouped_grad_out.device, + dtype=grouped_grad_out.dtype, + ).permute(0, 2, 1) + + kernels.ops.grouped_bwd_W( + 
DY=grouped_grad_out, + X=grouped_x, expert_offsets=expert_offsets, + DW=d_weights, E=expert_weights.size(0) ) - d_expanded_input = kernels.ops.scatter2scatter( - X=grouped_grad_out, x_grouped=True, + + kernels.ops.scatter2scatter( + X=grouped_grad_out, W=expert_weights.permute(0, 2, 1), - padded_block_idxs=padded_block_idxs, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, - k=1, + padded_block_idxs=padded_block_idxs, + out=d_expanded_input, # Reuse grouped_x buffer + FAN_OUT=1, + x_grouped=True, y_grouped=grouped_in, - out=d_expanded_input # Reuse grouped_x buffer ) if k == 1: From 0b9f87ddd943a5569e57f4240e823dd2f758e550 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 20:15:10 -0400 Subject: [PATCH 04/20] Modified. --- scattermoe/kernels/__init__.py | 1 + scattermoe/kernels/compileable_ops.py | 2 +- scattermoe/kernels/triton.py | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scattermoe/kernels/__init__.py b/scattermoe/kernels/__init__.py index bb2ba7f..3a6cb49 100644 --- a/scattermoe/kernels/__init__.py +++ b/scattermoe/kernels/__init__.py @@ -1,6 +1,7 @@ from . import compileable_ops as ops from . import single +BLOCK_M = 128 def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): # there is an overhead of launching a custom op so we only use the custom op when compiling if torch.compiler.is_compiling(): diff --git a/scattermoe/kernels/compileable_ops.py b/scattermoe/kernels/compileable_ops.py index 82b6685..ffe0a11 100644 --- a/scattermoe/kernels/compileable_ops.py +++ b/scattermoe/kernels/compileable_ops.py @@ -4,10 +4,10 @@ from torch.library import custom_op as torch_custom_op from .triton import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel +from . import BLOCK_M LIBRARY_NAME = "scattermoe" -BLOCK_M = 128 torch._dynamo.config.capture_scalar_outputs = True diff --git a/scattermoe/kernels/triton.py b/scattermoe/kernels/triton.py index b59f1fd..d0c7467 100644 --- a/scattermoe/kernels/triton.py +++ b/scattermoe/kernels/triton.py @@ -2,7 +2,6 @@ import triton.language as tl -BLOCK_M = 128 @triton.autotune( From 0b39d130b2378f7013735da972d11b67226950d0 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 20:16:51 -0400 Subject: [PATCH 05/20] Modified. --- scattermoe/kernels/__init__.py | 3 ++- scattermoe/kernels/compileable_ops.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scattermoe/kernels/__init__.py b/scattermoe/kernels/__init__.py index 3a6cb49..8077919 100644 --- a/scattermoe/kernels/__init__.py +++ b/scattermoe/kernels/__init__.py @@ -1,7 +1,8 @@ from . import compileable_ops as ops from . import single -BLOCK_M = 128 +BLOCK_M = ops.BLOCK_M + def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): # there is an overhead of launching a custom op so we only use the custom op when compiling if torch.compiler.is_compiling(): diff --git a/scattermoe/kernels/compileable_ops.py b/scattermoe/kernels/compileable_ops.py index ffe0a11..82b6685 100644 --- a/scattermoe/kernels/compileable_ops.py +++ b/scattermoe/kernels/compileable_ops.py @@ -4,10 +4,10 @@ from torch.library import custom_op as torch_custom_op from .triton import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel -from . 
import BLOCK_M LIBRARY_NAME = "scattermoe" +BLOCK_M = 128 torch._dynamo.config.capture_scalar_outputs = True From 7b27c2b1ecc3f2c05e09012d52835003b355ef6d Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 20:17:55 -0400 Subject: [PATCH 06/20] Import torch --- scattermoe/kernels/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scattermoe/kernels/__init__.py b/scattermoe/kernels/__init__.py index 8077919..0267cf8 100644 --- a/scattermoe/kernels/__init__.py +++ b/scattermoe/kernels/__init__.py @@ -1,3 +1,4 @@ +import torch from . import compileable_ops as ops from . import single From 8b4046447be59ea9cc45513c098a85a4ccca4787 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Mon, 7 Oct 2024 00:45:59 +0000 Subject: [PATCH 07/20] Output wrong. --- scattermoe/kernels/compileable_ops.py | 2 +- scattermoe/mlp.py | 4 ++-- scattermoe/parallel_experts.py | 20 ++++++++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/scattermoe/kernels/compileable_ops.py b/scattermoe/kernels/compileable_ops.py index 82b6685..528fba1 100644 --- a/scattermoe/kernels/compileable_ops.py +++ b/scattermoe/kernels/compileable_ops.py @@ -231,4 +231,4 @@ def group( if torch.compiler.is_compiling(): _group_compileable(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) else: - _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) \ No newline at end of file + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) diff --git a/scattermoe/mlp.py b/scattermoe/mlp.py index b6fac94..d4761d6 100644 --- a/scattermoe/mlp.py +++ b/scattermoe/mlp.py @@ -77,8 +77,8 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te x_shape = x.size() x = x.view(-1, x_shape[-1]) with torch.no_grad(): - sorted_expert_idxs, sorted_scattered_idxs = kernels.ops.flatten_and_sort(expert_idxs) - padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts) + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten()) + padded_block_idxs, expert_offsets = kernels.padded_block_indices(sorted_expert_idxs, self.num_experts) h = self.experts( x, self.top_k, diff --git a/scattermoe/parallel_experts.py b/scattermoe/parallel_experts.py index 653420d..3fd7ebb 100644 --- a/scattermoe/parallel_experts.py +++ b/scattermoe/parallel_experts.py @@ -11,14 +11,15 @@ def forward( gates=None, grouped_in=False, grouped_out=False, ): with torch.device(x.device): - output = torch.empty((sorted_expert_idxs.size(0), W.size(-1)), device=X.device, dtype=X.dtype) + output = torch.empty((sorted_expert_idxs.size(0), expert_weights.size(-1)), + device=x.device, dtype=x.dtype) kernels.ops.scatter2scatter( X=x, W=expert_weights, - out=O, + out=output, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, padded_block_idxs=padded_block_idxs, - k=k, x_grouped=grouped_in, y_grouped=grouped_out + FAN_OUT=k, x_grouped=grouped_in, y_grouped=grouped_out ) if gates is not None: output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) @@ -77,9 +78,16 @@ def backward(ctx, grad_out): ) if grouped_in: grouped_x = x - d_expanded_input = None + d_expanded_input = torch.empty( + (grouped_grad_out.size(0), expert_weights.size(1)), + device=grouped_grad_out.device, dtype=grouped_grad_out.dtype) else: - grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k) + grouped_x = 
torch.empty(sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device) + kernels.ops.group( + x, sorted_scattered_idxs, + out=grouped_x, + fan_out=k + ) d_expanded_input = grouped_x d_weights = torch.zeros( @@ -90,7 +98,7 @@ def backward(ctx, grad_out): dtype=grouped_grad_out.dtype, ).permute(0, 2, 1) - kernels.ops.grouped_bwd_W( + kernels.ops.group_bwd_W( DY=grouped_grad_out, X=grouped_x, expert_offsets=expert_offsets, From 4d74534414c4652322444829227d3572827b8035 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 20:52:06 -0400 Subject: [PATCH 08/20] Check output first. --- tests/test_mlp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index e781b02..015128f 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -58,6 +58,7 @@ def test_mlp_correctness(self, length, x_dim, h_dim, E, k, dtype): err_dW1 = torch.abs(dW1_ - dW1) err_dW2 = torch.abs(dW2_ - dW2) tolerance = 1e-2 + print("Y error:", err_Y.max()) assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() assert err_dX.max() < tolerance, "dX error too large: max %0.05f" % err_dX.max() assert err_dg.max() < tolerance, "dg error too large: max %0.05f" % err_dg.max() From df02d37b0e678365cdc1e0b6a873c1c0e3b59f0f Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 20:58:21 -0400 Subject: [PATCH 09/20] Modified. --- scattermoe/parallel_experts.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/scattermoe/parallel_experts.py b/scattermoe/parallel_experts.py index 3fd7ebb..b636797 100644 --- a/scattermoe/parallel_experts.py +++ b/scattermoe/parallel_experts.py @@ -54,17 +54,17 @@ def backward(ctx, grad_out): grouped_out = ctx.grouped_out # print("backward") with torch.device(grad_out.device): - if gates is not None: - # calculate gates gradient - d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1) - gates_flat = gates.flatten() - gate_fan = gates.size(1) - grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later - else: + if gates is None: d_gates = None gates_flat = None gate_fan = 1 grouped_grad_out = None + else: + # calculate gates gradient + d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later if grouped_out: grouped_grad_out = grad_out @@ -79,12 +79,16 @@ def backward(ctx, grad_out): if grouped_in: grouped_x = x d_expanded_input = torch.empty( - (grouped_grad_out.size(0), expert_weights.size(1)), - device=grouped_grad_out.device, dtype=grouped_grad_out.dtype) + (sorted_expert_idxs.size(0), expert_weights.size(1)), + device=x.device, dtype=x.dtype) else: - grouped_x = torch.empty(sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device) + grouped_x = torch.empty( + (sorted_scattered_idxs.size(0), x.size(1)), + dtype=x.dtype, device=x.device + ) kernels.ops.group( - x, sorted_scattered_idxs, + A=x, + sorted_expert_idxs=sorted_scattered_idxs, out=grouped_x, fan_out=k ) From 96f7d8d37237e2a50351c2e8ffc0c6534c7e17e3 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Mon, 7 Oct 2024 01:02:46 +0000 Subject: [PATCH 10/20] Modified test --- tests/test_mlp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 015128f..bc7efcd 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -59,8 +59,13 @@ def 
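# Reference sketch for the group(...) call above (shapes as in that call; illustrative
# only): grouped row i is a copy of token row sorted_scattered_idxs[i] // fan_out,
# optionally scaled by a per-row coefficient, so rows for the same expert end up contiguous.
import torch

def group_reference(x, sorted_scattered_idxs, fan_out=1, coeff=None):
    grouped = x[sorted_scattered_idxs // fan_out]
    if coeff is not None:                 # e.g. the flattened gates in the backward pass
        grouped = grouped * coeff[sorted_scattered_idxs].unsqueeze(1)
    return grouped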
test_mlp_correctness(self, length, x_dim, h_dim, E, k, dtype): err_dW2 = torch.abs(dW2_ - dW2) tolerance = 1e-2 print("Y error:", err_Y.max()) + print("dg:", err_dg.max()) + print("dW1:", err_dW1.max()) + print("dW2:", err_dW2.max()) + print("dX:", err_dX.max()) assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() - assert err_dX.max() < tolerance, "dX error too large: max %0.05f" % err_dX.max() assert err_dg.max() < tolerance, "dg error too large: max %0.05f" % err_dg.max() assert err_dW1.max() < tolerance, "dW1 error too large: max %0.05f" % err_dW1.max() assert err_dW2.max() < tolerance, "dW2 error too large: max %0.05f" % err_dW2.max() + assert err_dX.max() < tolerance, "dX error too large: max %0.05f" % err_dX.max() + From 47ab090dfc0d76a8d39a5432476b532270c4bd99 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 21:43:27 -0400 Subject: [PATCH 11/20] SOmething still wrong. --- scattermoe/parallel_experts.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/scattermoe/parallel_experts.py b/scattermoe/parallel_experts.py index b636797..2ac53ec 100644 --- a/scattermoe/parallel_experts.py +++ b/scattermoe/parallel_experts.py @@ -15,21 +15,17 @@ def forward( device=x.device, dtype=x.dtype) kernels.ops.scatter2scatter( X=x, W=expert_weights, - out=output, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, padded_block_idxs=padded_block_idxs, + out=output, FAN_OUT=k, x_grouped=grouped_in, y_grouped=grouped_out ) - if gates is not None: - output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) - output = torch.bmm( - gates[:, None, :], - output_expanded - ).squeeze(1) - else: + if gates is None: output_expanded = None - + else: + output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) + output = torch.bmm(gates.unsqueeze(1), output_expanded).squeeze(1) ctx.save_for_backward( x, expert_weights, sorted_expert_idxs, @@ -53,7 +49,7 @@ def backward(ctx, grad_out): grouped_in = ctx.grouped_in grouped_out = ctx.grouped_out # print("backward") - with torch.device(grad_out.device): + with torch.device(x.device): if gates is None: d_gates = None gates_flat = None From ff23d15f61ee766139b670c33d399aaef09c4111 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 22:23:47 -0400 Subject: [PATCH 12/20] SHouldn't change anything --- scattermoe/parallel_experts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scattermoe/parallel_experts.py b/scattermoe/parallel_experts.py index 2ac53ec..2540504 100644 --- a/scattermoe/parallel_experts.py +++ b/scattermoe/parallel_experts.py @@ -26,6 +26,7 @@ def forward( else: output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) output = torch.bmm(gates.unsqueeze(1), output_expanded).squeeze(1) + ctx.save_for_backward( x, expert_weights, sorted_expert_idxs, @@ -54,17 +55,16 @@ def backward(ctx, grad_out): d_gates = None gates_flat = None gate_fan = 1 - grouped_grad_out = None else: # calculate gates gradient d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1) gates_flat = gates.flatten() gate_fan = gates.size(1) - grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later if grouped_out: grouped_grad_out = grad_out else: + grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later kernels.ops.group( A=grad_out, sorted_expert_idxs=sorted_expert_idxs, From 8188596f37640b6e7055ba220a1bb367dd8bd4c7 Mon Sep 17 00:00:00 
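# Sketch (toy sizes, random data) of the gate combination in the forward above and its
# gradient in the backward: y_t = sum_j gates[t, j] * h[t, j], so d_gates[t, j] is the
# dot product of h[t, j] with dy_t; the bmm calls are a batched form of these products.
import torch

T, k, d = 4, 2, 8
gates = torch.rand(T, k, requires_grad=True)
h = torch.randn(T, k, d)                                  # per-(token, expert) outputs
y = torch.bmm(gates.unsqueeze(1), h).squeeze(1)           # (T, d) combined output
dy = torch.randn_like(y)

d_gates = torch.bmm(h, dy.unsqueeze(2)).squeeze(-1)       # (T, k), as in backward above
y.backward(dy)
print(torch.allclose(gates.grad, d_gates))                # True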
2001 From: Shawn Tan Date: Sun, 6 Oct 2024 22:29:31 -0400 Subject: [PATCH 13/20] Turn off allow_tf32 --- scattermoe/kernels/compileable_ops.py | 5 +++-- scattermoe/kernels/triton.py | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/scattermoe/kernels/compileable_ops.py b/scattermoe/kernels/compileable_ops.py index 528fba1..41d5fd0 100644 --- a/scattermoe/kernels/compileable_ops.py +++ b/scattermoe/kernels/compileable_ops.py @@ -9,6 +9,7 @@ LIBRARY_NAME = "scattermoe" BLOCK_M = 128 torch._dynamo.config.capture_scalar_outputs = True +ALLOW_TF32 = False # bincount is not compilable @@ -64,7 +65,7 @@ def _scatter2scatter( E=W.size(0), BLOCK_M=BLOCK_M, ACC_TYPE=tl.float32, - allow_tf32=torch.backends.cudnn.allow_tf32, + allow_tf32=torch.backends.cudnn.allow_tf32 and ALLOW_TF32, x_grouped=x_grouped, y_grouped=y_grouped, ) @@ -157,7 +158,7 @@ def _group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor K=X.size(-1), # ACC_TYPE: tl.constexpr, ACC_TYPE=tl.float32, - allow_tf32=torch.backends.cudnn.allow_tf32, + allow_tf32=torch.backends.cudnn.allow_tf32 and ALLOW_TF32, ) diff --git a/scattermoe/kernels/triton.py b/scattermoe/kernels/triton.py index d0c7467..89274ce 100644 --- a/scattermoe/kernels/triton.py +++ b/scattermoe/kernels/triton.py @@ -2,8 +2,6 @@ import triton.language as tl - - @triton.autotune( configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], key=["N", "K"], From 78985d91c56ae7a89d91e0050b5ca8d515c53653 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 22:57:09 -0400 Subject: [PATCH 14/20] Direct copy. --- scattermoe/mlp.py | 5 +- scattermoe/triton_implementation/__init__.py | 49 ++++ scattermoe/triton_implementation/kernels.py | 250 ++++++++++++++++++ .../triton_implementation/ops/__init__.py | 217 +++++++++++++++ .../ops/compileable_ops.py | 234 ++++++++++++++++ 5 files changed, 753 insertions(+), 2 deletions(-) create mode 100644 scattermoe/triton_implementation/__init__.py create mode 100644 scattermoe/triton_implementation/kernels.py create mode 100644 scattermoe/triton_implementation/ops/__init__.py create mode 100644 scattermoe/triton_implementation/ops/compileable_ops.py diff --git a/scattermoe/mlp.py b/scattermoe/mlp.py index d4761d6..93a317d 100644 --- a/scattermoe/mlp.py +++ b/scattermoe/mlp.py @@ -2,8 +2,9 @@ from torch import nn from torch.nn import functional as F -from . import kernels -from .parallel_experts import ParallelExperts +# from . 
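# Minimal sketch of the effect of the ALLOW_TF32 gate from patch 13 above: the flag handed
# to the Triton kernels is forced to False regardless of the global PyTorch setting, so the
# kernels' fp32 matmuls avoid TF32 rounding (10 mantissa bits on Ampere+ GPUs) and presumably
# match the fp32 reference used in the tests more closely. Illustration only.
import torch

ALLOW_TF32 = False
allow_tf32 = torch.backends.cudnn.allow_tf32 and ALLOW_TF32   # always False here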
import kernels +# from .parallel_experts import ParallelExperts +from triton_implementation import ParallelExperts class GLUMLP(nn.Module): def __init__( diff --git a/scattermoe/triton_implementation/__init__.py b/scattermoe/triton_implementation/__init__.py new file mode 100644 index 0000000..99085b7 --- /dev/null +++ b/scattermoe/triton_implementation/__init__.py @@ -0,0 +1,49 @@ +from typing import Callable + +import torch +import torch.nn as nn + +from .ops import padded_block_indices, scattered_experts + + +class ParallelExperts(nn.Module): + def __init__(self, num_experts, input_size, output_size) -> None: + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + self.reset_parameters() + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + + def extra_repr(self): + return 'num_experts={}, input_size={}, output_size={}'.format( + self.num_experts, self.input_size, self.output_size) + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.02) + + def forward( + self, + inputs, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + ): + return scattered_experts( + inputs, + self.weight.permute(0, 2, 1), + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + grouped_in, + grouped_out, + ) + diff --git a/scattermoe/triton_implementation/kernels.py b/scattermoe/triton_implementation/kernels.py new file mode 100644 index 0000000..b59f1fd --- /dev/null +++ b/scattermoe/triton_implementation/kernels.py @@ -0,0 +1,250 @@ +import triton +import triton.language as tl + + +BLOCK_M = 128 + + +@triton.autotune( + configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], + key=["N", "K"], +) +@triton.jit +def scatter2scatter_triton_kernel( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + Y_ptr, + stride_ym, + stride_yn, + grouped_idx_ptr, + expert_idxs_ptr, + block_start_idx_ptr, + FAN_OUT, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + x_grouped, + y_grouped, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load(block_start_idx_ptr + M_block_id) + + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) + + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(K, BLOCK_K) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for K_block_id in range(0, iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, 
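# Hedged usage sketch for the ParallelExperts module defined above (needs a CUDA device,
# since the ops are Triton kernels; sizes and the routing logits below are made up). The
# routing bookkeeping mirrors the pattern used in mlp.py in this series.
import torch
from scattermoe.triton_implementation import ParallelExperts, padded_block_indices

E, d_in, d_out, T, k = 8, 64, 128, 512, 2
experts = ParallelExperts(E, d_in, d_out).cuda()
x = torch.randn(T, d_in, device="cuda")
logits = torch.randn(T, E, device="cuda")
gates, expert_idxs = torch.topk(logits.softmax(-1), k)            # (T, k) each

with torch.no_grad():
    sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())
    padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, E)

y = experts(x, k, sorted_expert_idxs, sorted_scattered_idxs,
            padded_block_idxs, expert_offsets, gates=gates)        # (T, d_out)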
None]) + + if no_n_mask or K_block_id < (iters - 1): + w = tl.load(W_blk_ptrs) + else: + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + +@triton.autotune( + configs=[ + # different block M and reducing stages + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4), + # keep 4 stages and keep two 64 block sizes + # - NOTE: these can get good performances for low M, but for large M the variation + # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + ], + key=["N", "K"], +) +@triton.jit +def groupXtY_triton_kernel( + DY_ptr, + stride_dym, + stride_dyk, + X_ptr, + stride_xm, + stride_xn, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + expert_offsets_ptr, + K: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for i in range(0, iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + + if no_k_mask: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + + if no_n_mask: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) + + DW_blk_ptrs = 
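# Plain-PyTorch reference (illustrative only) for the groupXtY kernel above: expert e's
# weight gradient is X_e^T @ dY_e, where X_e and dY_e are the contiguous row ranges owned
# by expert e in the grouped buffers, delimited by the cumulative counts in expert_offsets.
import torch

def group_bwd_W_reference(DY, X, expert_offsets):
    # DY: (rows, d_out), X: (rows, d_in), expert_offsets: (E,) cumulative row counts
    E = expert_offsets.numel()
    DW = torch.zeros(E, X.size(1), DY.size(1), dtype=DY.dtype, device=DY.device)
    start = 0
    for e in range(E):
        end = int(expert_offsets[e])
        if end > start:
            DW[e] = X[start:end].t() @ DY[start:end]
        start = end
    return DW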
DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + + +@triton.autotune(configs=[triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4)], key=["K"]) +@triton.jit +def group_triton_kernel( + src_ptr, + stride_sn, + stride_sk, + has_coeff: tl.constexpr, + coeff_ptr, + FAN_OUT, + tgt_ptr, + stride_tn, + stride_ti, + grouped_idx_ptr, + N, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_block_id = pid + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) + N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) + + K_blk = tl.arange(0, BLOCK_K) + src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + no_k_mask = K % BLOCK_K == 0 + + for i in range(0, iters): + if no_k_mask or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=mask) + + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/scattermoe/triton_implementation/ops/__init__.py b/scattermoe/triton_implementation/ops/__init__.py new file mode 100644 index 0000000..861d4e8 --- /dev/null +++ b/scattermoe/triton_implementation/ops/__init__.py @@ -0,0 +1,217 @@ +import torch + +from .compileable_ops import compileable_bincount, group, group_bwd_W, scatter2scatter + + +BLOCK_M = 128 +torch._dynamo.config.capture_scalar_outputs = True + + +def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): + # there is an overhead of launching a custom op so we only use the custom op when compiling + if torch.compiler.is_compiling(): + expert_counts = compileable_bincount(sorted_experts_idxs, k) + else: + expert_counts = sorted_experts_idxs.bincount(minlength=k) + + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 + padded_expert_block_end = padded_block_counts.cumsum(-1) + expert_boundaries_end = expert_counts.cumsum(-1) + expert_boundaries_start = expert_boundaries_end - expert_counts + padded_expert_block_start = padded_expert_block_end - padded_block_counts + + block_idxs = torch.arange( + padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device + ).unsqueeze(1) + + block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end) + expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) + + return expanded_block_idxs, expert_boundaries_end + + +class _ScatteredExperts(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + ): + output = 
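# Worked toy example for padded_block_indices above (block size 4 instead of the real
# BLOCK_M = 128, counts made up): each expert's row count is rounded up to a whole number
# of blocks, and every padded block records the row offset where its expert's rows start,
# so each Triton block only ever serves a single expert.
import torch
from scattermoe.triton_implementation.ops import padded_block_indices

sorted_experts_idxs = torch.tensor([0, 0, 0, 1, 2, 2, 2, 2, 2])    # counts: 3, 1, 5
block_idxs, offsets = padded_block_indices(sorted_experts_idxs, k=3, N_BLOCK_SIZE=4)
# padded block counts -> [1, 1, 2]               (ceil(count / 4) per expert)
# offsets             -> tensor([3, 4, 9])       (cumulative row counts per expert)
# block_idxs          -> tensor([0, 3, 4, 8])    (start row of each block's expert slice)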
torch.empty(sorted_expert_idxs.size(0), expert_weights.size(-1), device=x.device, dtype=x.dtype) + + scatter2scatter( + X=x, + W=expert_weights, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=output, + FAN_OUT=k, + x_grouped=grouped_in, + y_grouped=grouped_out, + ) + + if gates is None: + output_expanded = None + else: + output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) + output = torch.bmm(gates.unsqueeze(1), output_expanded).squeeze(1) + + ctx.save_for_backward( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + output_expanded, + ) + + ctx.grouped_in = grouped_in + ctx.grouped_out = grouped_out + ctx.k = k + + return output + + @staticmethod + def backward(ctx, grad_out): + ( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + output_expanded, + ) = ctx.saved_tensors + k = ctx.k + grouped_in = ctx.grouped_in + grouped_out = ctx.grouped_out + + if gates is None: + d_gates = None + gates_flat = None + gate_fan = 1 + grouped_grad_out = None + else: + # calculate gates gradient + d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + # print("expanded and grouping") + grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later + + if grouped_out: + grouped_grad_out = grad_out + else: + group( + A=grad_out, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_grad_out, + coeff=gates_flat, + fan_out=gate_fan, + ) + + if grouped_in: + grouped_x = x + d_expanded_input = torch.empty( + sorted_expert_idxs.size(0), expert_weights.size(1), device=x.device, dtype=x.dtype + ) + else: + grouped_x = torch.empty(sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device) + group( + A=x, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_x, + fan_out=k, + ) + + d_expanded_input = grouped_x + + d_weights = torch.zeros( + expert_weights.size(0), + grouped_grad_out.size(-1), + grouped_x.size(-1), + device=grouped_grad_out.device, + dtype=grouped_grad_out.dtype, + ).permute(0, 2, 1) + + group_bwd_W( + DY=grouped_grad_out, + X=grouped_x, + expert_offsets=expert_offsets, + DW=d_weights, + E=expert_weights.size(0), + ) + + scatter2scatter( + X=grouped_grad_out, + W=expert_weights.permute(0, 2, 1), + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=d_expanded_input, + FAN_OUT=1, + x_grouped=True, + y_grouped=grouped_in, + ) + + if k == 1: + d_input = d_expanded_input + else: + d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2) + + # print("backward end.") + return ( + # x, expert_weights, k, + d_input, + d_weights, + None, + # sorted_expert_idxs, sorted_scattered_idxs, + None, + None, + # padded_block_idxs, expert_offsets, + None, + None, + # gates + d_gates, + None, + None, + ) + + +def scattered_experts( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, +): + return _ScatteredExperts.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + grouped_in, + grouped_out, + ) diff --git 
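# Small sketch of the final reduction in the backward above (ungrouped case): with fan-out
# k, token t owns the k consecutive rows t*k .. t*k + k - 1 of d_expanded_input, one per
# selected expert, and its input gradient is their sum. Toy shapes only.
import torch

T, k, d = 3, 2, 4
d_expanded_input = torch.randn(T * k, d)
d_input = d_expanded_input.view(T, k, d).sum(-2)   # (T, d): sum the k expert copies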
a/scattermoe/triton_implementation/ops/compileable_ops.py b/scattermoe/triton_implementation/ops/compileable_ops.py new file mode 100644 index 0000000..ac60af3 --- /dev/null +++ b/scattermoe/triton_implementation/ops/compileable_ops.py @@ -0,0 +1,234 @@ +import torch +import triton +import triton.language as tl + +from .....constants import LIBRARY_NAME +from ....utils import torch_custom_op +from ..kernels import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel + + +BLOCK_M = 128 +torch._dynamo.config.capture_scalar_outputs = True + + +# bincount is not compilable +@torch_custom_op(f"{LIBRARY_NAME}::bincount", mutates_args={}) +def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength) + + +@compileable_bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, dtype=torch.long, device=x.device) + + +def _scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * FAN_OUT + assert out.size(0) == sorted_expert_idxs.size(0) + assert out.size(1) == W.size(-1) + + grid = lambda meta: (padded_block_idxs.size(0) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + scatter2scatter_triton_kernel[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + W.stride(0), + W.stride(1), + W.stride(2), + # Y_ptr, stride_ym, stride_yn, + out, + out.stride(0), + out.stride(1), + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + block_start_idx_ptr=padded_block_idxs, + FAN_OUT=FAN_OUT, + M=X.size(0), + K=X.size(1), + N=out.size(1), + E=W.size(0), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::scatter2scatter", mutates_args={"out"}) +def _scatter2scatter_compileable( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + if torch.compiler.is_compiling(): + _scatter2scatter_compileable( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + else: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + 
padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def _group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + grid = lambda meta: (E * triton.cdiv(meta["K"], meta["BLOCK_K"]), triton.cdiv(meta["N"], meta["BLOCK_N"])) + + groupXtY_triton_kernel[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, + DY.stride(0), + DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, + X.stride(0), + X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, + DW.stride(0), + DW.stride(1), + DW.stride(2), + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, + N=DY.size(-1), + K=X.size(-1), + # ACC_TYPE: tl.constexpr, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group_bwd_W", mutates_args={"DW"}) +def _group_bwd_W_compileable( + DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int +) -> None: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + if torch.compiler.is_compiling(): + _group_bwd_W_compileable(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + else: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def _group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + + grid = lambda meta: (triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + group_triton_kernel[grid]( + # A_ptr, stride_an, stride_ai, + A, + A.stride(0), + A.stride(1), + coeff is not None, + coeff, + fan_out, + # Y_ptr, stride_yn, stride_yk, + out, + out.stride(0), + out.stride(1), + # grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, + K, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group", mutates_args={"out"}) +def _group_compileable( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) + + +def group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + if torch.compiler.is_compiling(): + _group_compileable(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) + else: + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) From 12b800f1af406f79c595c2cf39078cb1b248a163 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 22:58:14 -0400 Subject: [PATCH 15/20] relative import. --- scattermoe/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scattermoe/mlp.py b/scattermoe/mlp.py index 93a317d..f153cfc 100644 --- a/scattermoe/mlp.py +++ b/scattermoe/mlp.py @@ -4,7 +4,7 @@ # from . 
import kernels # from .parallel_experts import ParallelExperts -from triton_implementation import ParallelExperts +from .triton_implementation import ParallelExperts class GLUMLP(nn.Module): def __init__( From cc00cc7bf39d8aef2086d9a450fa0a7217fa08ec Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 23:02:10 -0400 Subject: [PATCH 16/20] Modified call. --- scattermoe/mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scattermoe/mlp.py b/scattermoe/mlp.py index f153cfc..7b4835f 100644 --- a/scattermoe/mlp.py +++ b/scattermoe/mlp.py @@ -4,7 +4,7 @@ # from . import kernels # from .parallel_experts import ParallelExperts -from .triton_implementation import ParallelExperts +from .triton_implementation import ParallelExperts, padded_block_indices class GLUMLP(nn.Module): def __init__( @@ -33,7 +33,7 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te x = x.view(-1, x_shape[-1]) with torch.no_grad(): sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten()) - padded_block_idxs, expert_offsets = kernels.padded_block_indices(sorted_expert_idxs, self.num_experts) + padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts) h, gates = self.experts( x, self.top_k, From 6dfc2dcae02d12b219852e8b2b5e0fb44d526bee Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Mon, 7 Oct 2024 03:49:40 +0000 Subject: [PATCH 17/20] Tests pass, compile not working. --- scattermoe/mlp.py | 2 +- scattermoe/triton_implementation/ops/compileable_ops.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scattermoe/mlp.py b/scattermoe/mlp.py index 7b4835f..b7c984e 100644 --- a/scattermoe/mlp.py +++ b/scattermoe/mlp.py @@ -79,7 +79,7 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te x = x.view(-1, x_shape[-1]) with torch.no_grad(): sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten()) - padded_block_idxs, expert_offsets = kernels.padded_block_indices(sorted_expert_idxs, self.num_experts) + padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts) h = self.experts( x, self.top_k, diff --git a/scattermoe/triton_implementation/ops/compileable_ops.py b/scattermoe/triton_implementation/ops/compileable_ops.py index ac60af3..0ad23ac 100644 --- a/scattermoe/triton_implementation/ops/compileable_ops.py +++ b/scattermoe/triton_implementation/ops/compileable_ops.py @@ -2,11 +2,11 @@ import triton import triton.language as tl -from .....constants import LIBRARY_NAME -from ....utils import torch_custom_op +from torch.library import custom_op as torch_custom_op from ..kernels import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel +LIBRARY_NAME = "scattermoe" BLOCK_M = 128 torch._dynamo.config.capture_scalar_outputs = True From 8d028251b429dc0e4406cfc080a2ad2f22bc3578 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 6 Oct 2024 23:53:03 -0400 Subject: [PATCH 18/20] custom_op wrapper. --- scattermoe/__init__.py | 2 + .../ops/compileable_ops.py | 2 +- scattermoe/triton_implementation/utils.py | 39 +++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 scattermoe/triton_implementation/utils.py diff --git a/scattermoe/__init__.py b/scattermoe/__init__.py index a82fe80..098c3fc 100644 --- a/scattermoe/__init__.py +++ b/scattermoe/__init__.py @@ -1,3 +1,5 @@ from . import kernels from . import parallel_experts from . 
import mlp +from triton_implementation import padded_block_indices +from parallel_experts import ParallelExperts diff --git a/scattermoe/triton_implementation/ops/compileable_ops.py b/scattermoe/triton_implementation/ops/compileable_ops.py index 0ad23ac..23f547b 100644 --- a/scattermoe/triton_implementation/ops/compileable_ops.py +++ b/scattermoe/triton_implementation/ops/compileable_ops.py @@ -2,8 +2,8 @@ import triton import triton.language as tl -from torch.library import custom_op as torch_custom_op from ..kernels import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel +from ..utils import torch_custom_op LIBRARY_NAME = "scattermoe" diff --git a/scattermoe/triton_implementation/utils.py b/scattermoe/triton_implementation/utils.py new file mode 100644 index 0000000..6270c11 --- /dev/null +++ b/scattermoe/triton_implementation/utils.py @@ -0,0 +1,39 @@ +from typing import Any, Callable, Iterable + +import torch + + +try: + from torch.library import custom_op + + _IS_CUSTOM_OP_IN_PYTORCH = True +except: + _IS_CUSTOM_OP_IN_PYTORCH = False + + +class _IdentityOp: + def __init__(self, fn: Callable) -> None: + self.fn = fn + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.fn(*args, **kwargs) + + def register_fake(self, fn: Callable) -> Callable: + return fn + + +def torch_custom_op( + name: str, + fn: Callable | None = None, + /, + *, + mutates_args: str | Iterable[str], + device_types: torch.device = None, + schema: str | None = None, +) -> Callable | _IdentityOp: + if _IS_CUSTOM_OP_IN_PYTORCH: + op = custom_op(name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema) + else: + op = _IdentityOp if fn is None else _IdentityOp(fn) + + return op \ No newline at end of file From 7a317de3e3c0068624cded71d9065a7cf3625f7a Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Mon, 7 Oct 2024 04:07:49 +0000 Subject: [PATCH 19/20] Works. --- scattermoe/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scattermoe/__init__.py b/scattermoe/__init__.py index 098c3fc..ca2a509 100644 --- a/scattermoe/__init__.py +++ b/scattermoe/__init__.py @@ -1,5 +1,5 @@ from . import kernels from . import parallel_experts from . import mlp -from triton_implementation import padded_block_indices -from parallel_experts import ParallelExperts +from .triton_implementation import padded_block_indices +from .parallel_experts import ParallelExperts From a868f186923359c082b72a18db13eee43f52e44a Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Tue, 8 Oct 2024 15:29:55 -0400 Subject: [PATCH 20/20] Refactored expert block computation. 
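# Usage sketch for the torch_custom_op wrapper added above: on PyTorch builds that provide
# torch.library.custom_op it behaves like the real decorator; on older builds it falls back
# to _IdentityOp, which simply calls the wrapped function and turns register_fake into a
# no-op. The op name and function below are hypothetical.
import torch
from scattermoe.triton_implementation.utils import torch_custom_op

@torch_custom_op("scattermoe::row_sums", mutates_args={})
def row_sums(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)

@row_sums.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty(x.shape[:-1], dtype=x.dtype, device=x.device)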
--- scattermoe/triton_implementation/kernels.py | 63 +++++++++++++-------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/scattermoe/triton_implementation/kernels.py b/scattermoe/triton_implementation/kernels.py index b59f1fd..cb0e60e 100644 --- a/scattermoe/triton_implementation/kernels.py +++ b/scattermoe/triton_implementation/kernels.py @@ -11,16 +11,9 @@ ) @triton.jit def scatter2scatter_triton_kernel( - X_ptr, - stride_xm, - stride_xk, - W_ptr, - stride_we, - stride_wk, - stride_wn, - Y_ptr, - stride_ym, - stride_yn, + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + Y_ptr, stride_ym, stride_yn, grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr, @@ -47,34 +40,57 @@ def scatter2scatter_triton_kernel( M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) E_idx = tl.min(E_idxs) + + E_mask, M_out_idx, N_block, N_mask, acc = compute_expert_block( + E_idx, E_idxs, + M_block, N_block_id, + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + grouped_idx_ptr, + FAN_OUT, K, N, + acc, + allow_tf32, + no_k_mask, no_n_mask, + x_grouped, y_grouped, ACC_TYPE, BLOCK_N, BLOCK_K + ) + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + +def compute_expert_block( + E_idx, E_idxs, + M_block, N_block_id, + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + grouped_idx_ptr, + FAN_OUT, K, N, + acc, + allow_tf32, + no_k_mask, no_n_mask, + x_grouped, y_grouped, + ACC_TYPE, BLOCK_N, BLOCK_K): E_mask = E_idxs == E_idx M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) - if x_grouped: M_in_idx = M_block else: M_in_idx = M_idx // FAN_OUT - if y_grouped: M_out_idx = M_block else: M_out_idx = M_idx - K_block = tl.arange(0, BLOCK_K) - N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) N_mask = N_block < N - X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) iters = tl.cdiv(K, BLOCK_K) - - no_k_mask = K % BLOCK_K == 0 - no_n_mask = N % BLOCK_N == 0 - for K_block_id in range(0, iters): if no_k_mask: x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) @@ -91,9 +107,7 @@ def scatter2scatter_triton_kernel( X_blk_ptrs += BLOCK_K * stride_xk W_blk_ptrs += BLOCK_K * stride_wk acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) - - Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) - tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + return E_mask, M_out_idx, N_block, N_mask, acc @triton.autotune( @@ -165,6 +179,7 @@ def groupXtY_triton_kernel( dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) no_k_mask = K % BLOCK_K == 0
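# Plain-PyTorch reference (illustrative, per-row loop) for what scatter2scatter computes,
# to make the kernel refactor above easier to follow: each grouped row is assigned to one
# expert and multiplied by that expert's weight matrix; x_grouped / y_grouped only choose
# between grouped order and the original scattered order for reads and writes.
import torch

def scatter2scatter_reference(X, W, sorted_expert_idxs, sorted_scattered_idxs,
                              FAN_OUT, x_grouped=False, y_grouped=False):
    rows = sorted_expert_idxs.size(0)
    out = torch.zeros(rows, W.size(-1), dtype=X.dtype, device=X.device)
    for i in range(rows):
        e = int(sorted_expert_idxs[i])
        src = i if x_grouped else int(sorted_scattered_idxs[i]) // FAN_OUT
        dst = i if y_grouped else int(sorted_scattered_idxs[i])
        out[dst] = X[src] @ W[e]
    return out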