diff --git a/scattermoe/__init__.py b/scattermoe/__init__.py index a82fe80..ca2a509 100644 --- a/scattermoe/__init__.py +++ b/scattermoe/__init__.py @@ -1,3 +1,5 @@ from . import kernels from . import parallel_experts from . import mlp +from .triton_implementation import padded_block_indices +from .parallel_experts import ParallelExperts diff --git a/scattermoe/kernels/__init__.py b/scattermoe/kernels/__init__.py index 14c0c07..0267cf8 100644 --- a/scattermoe/kernels/__init__.py +++ b/scattermoe/kernels/__init__.py @@ -1,2 +1,28 @@ -from . import ops -from . import single \ No newline at end of file +import torch +from . import compileable_ops as ops +from . import single + +BLOCK_M = ops.BLOCK_M + +def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): + # there is an overhead of launching a custom op so we only use the custom op when compiling + if torch.compiler.is_compiling(): + expert_counts = compileable_bincount(sorted_experts_idxs, k) + else: + expert_counts = sorted_experts_idxs.bincount(minlength=k) + + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 + padded_expert_block_end = padded_block_counts.cumsum(-1) + expert_boundaries_end = expert_counts.cumsum(-1) + expert_boundaries_start = expert_boundaries_end - expert_counts + padded_expert_block_start = padded_expert_block_end - padded_block_counts + + block_idxs = torch.arange( + padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device + ).unsqueeze(1) + + block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end) + expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) + + return expanded_block_idxs, expert_boundaries_end diff --git a/scattermoe/kernels/compileable_ops.py b/scattermoe/kernels/compileable_ops.py new file mode 100644 index 0000000..41d5fd0 --- /dev/null +++ b/scattermoe/kernels/compileable_ops.py @@ -0,0 +1,235 @@ +import torch +import triton +import triton.language as tl + +from torch.library import custom_op as torch_custom_op +from .triton import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel + + +LIBRARY_NAME = "scattermoe" +BLOCK_M = 128 +torch._dynamo.config.capture_scalar_outputs = True +ALLOW_TF32 = False + + +# bincount is not compilable +@torch_custom_op(f"{LIBRARY_NAME}::bincount", mutates_args={}) +def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength) + + +@compileable_bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, dtype=torch.long, device=x.device) + + +def _scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * FAN_OUT + assert out.size(0) == sorted_expert_idxs.size(0) + assert out.size(1) == W.size(-1) + + grid = lambda meta: (padded_block_idxs.size(0) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + scatter2scatter_triton_kernel[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + 
W.stride(0), + W.stride(1), + W.stride(2), + # Y_ptr, stride_ym, stride_yn, + out, + out.stride(0), + out.stride(1), + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + block_start_idx_ptr=padded_block_idxs, + FAN_OUT=FAN_OUT, + M=X.size(0), + K=X.size(1), + N=out.size(1), + E=W.size(0), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32 and ALLOW_TF32, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::scatter2scatter", mutates_args={"out"}) +def _scatter2scatter_compileable( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + if torch.compiler.is_compiling(): + _scatter2scatter_compileable( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + else: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def _group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + grid = lambda meta: (E * triton.cdiv(meta["K"], meta["BLOCK_K"]), triton.cdiv(meta["N"], meta["BLOCK_N"])) + + groupXtY_triton_kernel[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, + DY.stride(0), + DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, + X.stride(0), + X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, + DW.stride(0), + DW.stride(1), + DW.stride(2), + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, + N=DY.size(-1), + K=X.size(-1), + # ACC_TYPE: tl.constexpr, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32 and ALLOW_TF32, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group_bwd_W", mutates_args={"DW"}) +def _group_bwd_W_compileable( + DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int +) -> None: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + if torch.compiler.is_compiling(): + _group_bwd_W_compileable(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + else: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def _group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> 
None: + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + + grid = lambda meta: (triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + group_triton_kernel[grid]( + # A_ptr, stride_an, stride_ai, + A, + A.stride(0), + A.stride(1), + coeff is not None, + coeff, + fan_out, + # Y_ptr, stride_yn, stride_yk, + out, + out.stride(0), + out.stride(1), + # grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, + K, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group", mutates_args={"out"}) +def _group_compileable( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) + + +def group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + if torch.compiler.is_compiling(): + _group_compileable(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) + else: + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) diff --git a/scattermoe/kernels/ops.py b/scattermoe/kernels/ops.py deleted file mode 100644 index 2386b6a..0000000 --- a/scattermoe/kernels/ops.py +++ /dev/null @@ -1,389 +0,0 @@ -import torch -import triton -import triton.language as tl -from torch.nn import functional as F - -BLOCK_M = 128 -ALLOW_TF32 = False - -@torch.library.custom_op("scattermoe::bincount", mutates_args={}) -def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: - return x.bincount(minlength=minlength) - -@compileable_bincount.register_fake -def _(x: torch.Tensor, minlength: int) -> torch.Tensor: - return torch.empty(minlength, dtype=torch.long, device=x.device) - -@torch.compile -def flatten_and_sort(expert_idxs:torch.Tensor): - flattened_expert_idxs = expert_idxs.flatten() - sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) - return sorted_expert_idxs, sorted_scattered_idxs - -@torch.compile -def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int=BLOCK_M) : - expert_counts = compileable_bincount(sorted_experts_idxs, minlength=k) - padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 - padded_expert_block_end = padded_block_counts.cumsum(-1) - expert_boundaries_end = expert_counts.cumsum(-1) - expert_boundaries_start = expert_boundaries_end - expert_counts - padded_expert_block_start = padded_expert_block_end - padded_block_counts - block_idxs = torch.arange(padded_expert_block_end[-1], - dtype=sorted_experts_idxs.dtype, - device=sorted_experts_idxs.device) - block_mask = ( - (block_idxs[:, None] < padded_expert_block_start) | - (block_idxs[:, None] >= padded_expert_block_end) - ) - expanded_block_idxs = ( - N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) + - expert_boundaries_start - ) - expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) - return expanded_block_idxs, expert_boundaries_end - - - -def _scatter2scatter_configs(): - return [ - triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4), - ] - -@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], ) -@triton.heuristics({ - "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, - "NO_N_MASK": lambda args: (args['N'] % 
args['BLOCK_N']) == 0, -}) -@triton.jit -def _scatter2scatter( - X_ptr, stride_xm, stride_xk, - W_ptr, stride_we, stride_wk, stride_wn, - Y_ptr, stride_ym, stride_yn, - grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr, - FAN_OUT: tl.constexpr, - M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - OUT_M, - allow_tf32: tl.constexpr, - x_grouped: tl.constexpr, y_grouped: tl.constexpr, - NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr -): - pid = tl.program_id(axis=0) - - N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) - M_block_id = pid // N_BLOCK_COUNT - N_block_id = pid % N_BLOCK_COUNT - M_range = tl.arange(0, BLOCK_M) - block_start_idx = tl.load(block_start_idx_ptr + M_block_id) - # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) - M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) - E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) - E_idx = tl.min(E_idxs) - E_mask = E_idxs == E_idx - M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) - if x_grouped: - M_in_idx = M_block - else: - M_in_idx = M_idx // FAN_OUT - - if y_grouped: - M_out_idx = M_block - else: - M_out_idx = M_idx - - K_block = tl.arange(0, BLOCK_K) - - N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - N_mask = N_block < N - # N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) - # N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - - X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk - W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - iters = tl.cdiv(K, BLOCK_K) - for K_block_id in range(0, iters): - if NO_K_MASK: - x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) - if NO_N_MASK or K_block_id < (iters - 1): - w = tl.load(W_blk_ptrs) - else: - w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) - else: - K_mask = (K_block_id * BLOCK_K + K_block) < K - x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) - w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) - X_blk_ptrs += BLOCK_K * stride_xk - W_blk_ptrs += BLOCK_K * stride_wk - acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) - - Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) - tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) - -def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k, - padded_block_idxs, x_grouped=False, y_grouped=False, - out=None): - assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) - assert sorted_scattered_idxs.size(0) == X.size(0) * k - # Pre-kernel setup - x_dim = X.size(-1) - y_dim = W.size(-1) - L_scattered = sorted_expert_idxs.size(0) - if out is None: - O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) - else: - assert out.size(0) == L_scattered and out.size(1) == y_dim - O = out - - # with torch.cuda.device(X.device): - scatter2scatter_compileable(O, W, X, k, padded_block_idxs, sorted_expert_idxs, sorted_scattered_idxs, - x_grouped, y_grouped) - return O - - -@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"O"}) -def scatter2scatter_compileable( - O: torch.Tensor, - W: torch.Tensor, - X: torch.Tensor, - k: int, - padded_block_idxs: torch.Tensor, - sorted_expert_idxs: torch.Tensor, - sorted_scattered_idxs: torch.Tensor, - x_grouped: bool, 
y_grouped: bool) -> None: - def grid(META): - grid_num = ( - padded_block_idxs.size(0) * - triton.cdiv(META['N'], META['BLOCK_N']), - ) - return grid_num - - _scatter2scatter[grid]( - # X_ptr, stride_xm, stride_xk, - X, X.stride(0), X.stride(1), - # W_ptr, stride_we, stride_wk, stride_wn, - W, W.stride(0), W.stride(1), W.stride(2), - # Y_ptr, stride_ym, stride_yn, - O, O.stride(0), O.stride(1), - grouped_idx_ptr=sorted_scattered_idxs, - expert_idxs_ptr=sorted_expert_idxs, - block_start_idx_ptr=padded_block_idxs, - FAN_OUT=k, - M=X.size(0), - K=X.size(1), - N=O.size(1), E=W.size(0), - BLOCK_M=BLOCK_M, - ACC_TYPE=tl.float32, - OUT_M=O.size(0), - allow_tf32=ALLOW_TF32, - x_grouped=x_grouped, y_grouped=y_grouped, - ) - - -def _config_XtY(): - return [ - triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4), - ] - -def group_bwd_W(DY, X, expert_offsets, E): - DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype) - DW = DWt.permute(0, 2, 1) - groupXtY_compileable(E, DW, DY, X, expert_offsets) - return DW - - -@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"}) -def groupXtY_compileable( - E: int, - DW: torch.Tensor, - DY: torch.Tensor, - X: torch.Tensor, - expert_offsets: torch.Tensor) -> None: - def grid(META): - grid = ( - E * triton.cdiv(META['K'], META['BLOCK_K']), - triton.cdiv(META['N'], META['BLOCK_N']), - ) - return grid - - _groupXtY[grid]( - # DY_ptr, stride_dym, stride_dyk, - DY, DY.stride(0), DY.stride(1), - # X_ptr, stride_xm, stride_xn, - X, X.stride(0), X.stride(1), - # DW_ptr, stride_dwe, stride_dwk, stride_dwn, - DW, DW.stride(0), DW.stride(1), DW.stride(2), - # expert_offsets_ptr, - expert_offsets, - # K: tl.constexpr, N: tl.constexpr, - M=DY.size(0), N=DY.size(-1), K=X.size(-1), - # ACC_TYPE: tl.constexpr, - ACC_TYPE=tl.float32, - allow_tf32=True - ) - - -@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], ) -@triton.heuristics({ - "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, - "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, -}) -@triton.jit -def _groupXtY( - DY_ptr, stride_dym, stride_dyk, - X_ptr, stride_xm, stride_xn, - DW_ptr, stride_dwe, stride_dwk, stride_dwn, - expert_offsets_ptr, - M, K: tl.constexpr, N: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - allow_tf32: tl.constexpr, - NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr -): - pid0 = tl.program_id(axis=0) - pid1 = tl.program_id(axis=1) - num0 = tl.num_programs(0) - num1 = tl.num_programs(1) - # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128) - pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) - - K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) - E_idx = pid0 // K_BLOCK_COUNT - K_block_id = pid0 % K_BLOCK_COUNT - N_block_id = pid1 - - if E_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) - end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) - - if end_idx > start_idx: - M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) - - K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) - K_mask = K_block < K - K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) - - N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - N_mask = N_block < N - N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) - - M_idxs = M_block - xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm - dy_blk_ptrs = 
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk - - acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) - iters = tl.cdiv(end_idx - start_idx, BLOCK_M) - for i in range(0, iters): - M_mask = (i * BLOCK_M + M_block) < end_idx - if NO_K_MASK: - xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) - else: - xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) - if NO_N_MASK: - dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) - else: - dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) - # acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) - xt_blk_ptrs += BLOCK_M * stride_xm - dy_blk_ptrs += BLOCK_M * stride_dym - acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) - - - - DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn - acc = acc.to(DW_blk_ptrs.dtype.element_ty) - tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) - - -def _config_grouping(): - return [ - triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4), - # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), - # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4), - ] - -def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None): - N = sorted_expert_idxs.size(0) - K = A.size(1) - assert A.size(0) * fan_out == N - if out is not None: - Y = out - else: - Y = torch.empty((N, K), dtype=A.dtype, device=A.device) - group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs) - return Y - - -@torch.library.custom_op("scattermoe::group", mutates_args={"Y"}) -def group_compileable( - A: torch.Tensor, - K: int, - N: int, - Y: torch.Tensor, - coeff: torch.Tensor, has_coeff: bool, - fan_out: int, - sorted_expert_idxs: torch.Tensor) -> None: - def grid(META): - grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),) - return grid_num - _group[grid]( - # A_ptr, stride_an, stride_ai, - A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out, - # Y_ptr, stride_yn, stride_yk, - Y, Y.stride(0), Y.stride(1), - # grouped_idx_ptr, - sorted_expert_idxs, - # N: tl.constexpr, K: tl.constexpr, - N, K - ) - - -@triton.autotune(configs=_config_grouping(), key=['K']) -@triton.heuristics({ - "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0 -}) -@triton.jit -def _group( - src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr, - tgt_ptr, stride_tn, stride_ti, - grouped_idx_ptr, - N, K: tl.constexpr, - BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NO_K_MASK: tl.constexpr -): - pid = tl.program_id(axis=0) - - N_block_id = pid - N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) - N_mask = N_blk < N - N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) - N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) - - K_blk = tl.arange(0, BLOCK_K) - src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk - tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti - - if has_coeff: - c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] - - iters = tl.cdiv(K, BLOCK_K) - for i in range(0, iters): - if NO_K_MASK or i < iters - 1: - block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) - if has_coeff: - block *= c - tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) - - else: - K_mask = (i * BLOCK_K + K_blk) < K - mask = N_mask[:, None] & K_mask[None, :] - block = tl.load(src_blk_ptrs, mask=mask) - if 
has_coeff: - block *= c - tl.store(tgt_blk_ptrs, block, mask=mask) - src_blk_ptrs += BLOCK_K * stride_sk - tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/scattermoe/kernels/triton.py b/scattermoe/kernels/triton.py new file mode 100644 index 0000000..89274ce --- /dev/null +++ b/scattermoe/kernels/triton.py @@ -0,0 +1,247 @@ +import triton +import triton.language as tl + + +@triton.autotune( + configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], + key=["N", "K"], +) +@triton.jit +def scatter2scatter_triton_kernel( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + Y_ptr, + stride_ym, + stride_yn, + grouped_idx_ptr, + expert_idxs_ptr, + block_start_idx_ptr, + FAN_OUT, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + x_grouped, + y_grouped, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load(block_start_idx_ptr + M_block_id) + + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) + + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(K, BLOCK_K) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for K_block_id in range(0, iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + + if no_n_mask or K_block_id < (iters - 1): + w = tl.load(W_blk_ptrs) + else: + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + +@triton.autotune( + configs=[ + # different block M and reducing stages + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4), + # keep 4 stages and keep two 64 block sizes + # - NOTE: these can get good performances for low M, but for large M the variation + # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, 
num_warps=4), + ], + key=["N", "K"], +) +@triton.jit +def groupXtY_triton_kernel( + DY_ptr, + stride_dym, + stride_dyk, + X_ptr, + stride_xm, + stride_xn, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + expert_offsets_ptr, + K: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for i in range(0, iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + + if no_k_mask: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + + if no_n_mask: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) + + DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + + +@triton.autotune(configs=[triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4)], key=["K"]) +@triton.jit +def group_triton_kernel( + src_ptr, + stride_sn, + stride_sk, + has_coeff: tl.constexpr, + coeff_ptr, + FAN_OUT, + tgt_ptr, + stride_tn, + stride_ti, + grouped_idx_ptr, + N, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_block_id = pid + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) + N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) + + K_blk = tl.arange(0, BLOCK_K) + src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + no_k_mask = K % BLOCK_K == 0 + + for i in range(0, iters): + if no_k_mask or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, 
block, mask=N_mask[:, None]) + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=mask) + + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/scattermoe/mlp.py b/scattermoe/mlp.py index ef6d0a6..b7c984e 100644 --- a/scattermoe/mlp.py +++ b/scattermoe/mlp.py @@ -2,8 +2,9 @@ from torch import nn from torch.nn import functional as F -from . import kernels -from .parallel_experts import ParallelExperts +# from . import kernels +# from .parallel_experts import ParallelExperts +from .triton_implementation import ParallelExperts, padded_block_indices class GLUMLP(nn.Module): def __init__( @@ -31,8 +32,8 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te x_shape = x.size() x = x.view(-1, x_shape[-1]) with torch.no_grad(): - sorted_expert_idxs, sorted_scattered_idxs = kernels.ops.flatten_and_sort(expert_idxs) - padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts) + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten()) + padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts) h, gates = self.experts( x, self.top_k, @@ -77,8 +78,8 @@ def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Te x_shape = x.size() x = x.view(-1, x_shape[-1]) with torch.no_grad(): - sorted_expert_idxs, sorted_scattered_idxs = kernels.ops.flatten_and_sort(expert_idxs) - padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts) + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten()) + padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs, self.num_experts) h = self.experts( x, self.top_k, diff --git a/scattermoe/parallel_experts.py b/scattermoe/parallel_experts.py index 6621899..2540504 100644 --- a/scattermoe/parallel_experts.py +++ b/scattermoe/parallel_experts.py @@ -11,21 +11,21 @@ def forward( gates=None, grouped_in=False, grouped_out=False, ): with torch.device(x.device): - output = kernels.ops.scatter2scatter( + output = torch.empty((sorted_expert_idxs.size(0), expert_weights.size(-1)), + device=x.device, dtype=x.dtype) + kernels.ops.scatter2scatter( X=x, W=expert_weights, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, padded_block_idxs=padded_block_idxs, - k=k, x_grouped=grouped_in, y_grouped=grouped_out + out=output, + FAN_OUT=k, x_grouped=grouped_in, y_grouped=grouped_out ) - if gates is not None: - output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) - output = torch.bmm( - gates[:, None, :], - output_expanded - ).squeeze(1) - else: + if gates is None: output_expanded = None + else: + output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) + output = torch.bmm(gates.unsqueeze(1), output_expanded).squeeze(1) ctx.save_for_backward( x, expert_weights, @@ -50,46 +50,72 @@ def backward(ctx, grad_out): grouped_in = ctx.grouped_in grouped_out = ctx.grouped_out # print("backward") - with torch.device(grad_out.device): - if gates is not None: - # calculate gates gradient - d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1) - gates_flat = gates.flatten() - gate_fan = gates.size(1) - # print("expanded and grouping") - grouped_grad_out = output_expanded.flatten(0, 1) # reuse 
expanded buffer later - else: + with torch.device(x.device): + if gates is None: d_gates = None gates_flat = None gate_fan = 1 - grouped_grad_out = None + else: + # calculate gates gradient + d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) if grouped_out: grouped_grad_out = grad_out else: - grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs, - fan_out=gate_fan, coeff=gates_flat, - out=grouped_grad_out) + grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later + kernels.ops.group( + A=grad_out, + sorted_expert_idxs=sorted_expert_idxs, + out=grouped_grad_out, + coeff=gates_flat, + fan_out=gate_fan, + ) if grouped_in: grouped_x = x - d_expanded_input = None + d_expanded_input = torch.empty( + (sorted_expert_idxs.size(0), expert_weights.size(1)), + device=x.device, dtype=x.dtype) else: - grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k) + grouped_x = torch.empty( + (sorted_scattered_idxs.size(0), x.size(1)), + dtype=x.dtype, device=x.device + ) + kernels.ops.group( + A=x, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_x, + fan_out=k + ) d_expanded_input = grouped_x - d_weights = kernels.ops.group_bwd_W( - DY=grouped_grad_out, X=grouped_x, + + d_weights = torch.zeros( + expert_weights.size(0), + grouped_grad_out.size(-1), + grouped_x.size(-1), + device=grouped_grad_out.device, + dtype=grouped_grad_out.dtype, + ).permute(0, 2, 1) + + kernels.ops.group_bwd_W( + DY=grouped_grad_out, + X=grouped_x, expert_offsets=expert_offsets, + DW=d_weights, E=expert_weights.size(0) ) - d_expanded_input = kernels.ops.scatter2scatter( - X=grouped_grad_out, x_grouped=True, + + kernels.ops.scatter2scatter( + X=grouped_grad_out, W=expert_weights.permute(0, 2, 1), - padded_block_idxs=padded_block_idxs, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, - k=1, + padded_block_idxs=padded_block_idxs, + out=d_expanded_input, # Reuse grouped_x buffer + FAN_OUT=1, + x_grouped=True, y_grouped=grouped_in, - out=d_expanded_input # Reuse grouped_x buffer ) if k == 1: diff --git a/scattermoe/triton_implementation/__init__.py b/scattermoe/triton_implementation/__init__.py new file mode 100644 index 0000000..99085b7 --- /dev/null +++ b/scattermoe/triton_implementation/__init__.py @@ -0,0 +1,49 @@ +from typing import Callable + +import torch +import torch.nn as nn + +from .ops import padded_block_indices, scattered_experts + + +class ParallelExperts(nn.Module): + def __init__(self, num_experts, input_size, output_size) -> None: + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + self.reset_parameters() + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + + def extra_repr(self): + return 'num_experts={}, input_size={}, output_size={}'.format( + self.num_experts, self.input_size, self.output_size) + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.02) + + def forward( + self, + inputs, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + ): + return scattered_experts( + inputs, + self.weight.permute(0, 2, 1), + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + grouped_in, + grouped_out, + ) + diff --git a/scattermoe/triton_implementation/kernels.py 
b/scattermoe/triton_implementation/kernels.py new file mode 100644 index 0000000..cb0e60e --- /dev/null +++ b/scattermoe/triton_implementation/kernels.py @@ -0,0 +1,265 @@ +import triton +import triton.language as tl + + +BLOCK_M = 128 + + +@triton.autotune( + configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], + key=["N", "K"], +) +@triton.jit +def scatter2scatter_triton_kernel( + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + Y_ptr, stride_ym, stride_yn, + grouped_idx_ptr, + expert_idxs_ptr, + block_start_idx_ptr, + FAN_OUT, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + x_grouped, + y_grouped, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load(block_start_idx_ptr + M_block_id) + + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + E_idx = tl.min(E_idxs) + + E_mask, M_out_idx, N_block, N_mask, acc = compute_expert_block( + E_idx, E_idxs, + M_block, N_block_id, + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + grouped_idx_ptr, + FAN_OUT, K, N, + acc, + allow_tf32, + no_k_mask, no_n_mask, + x_grouped, y_grouped, ACC_TYPE, BLOCK_N, BLOCK_K + ) + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + +def compute_expert_block( + E_idx, E_idxs, + M_block, N_block_id, + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + grouped_idx_ptr, + FAN_OUT, K, N, + acc, + allow_tf32, + no_k_mask, no_n_mask, + x_grouped, y_grouped, + ACC_TYPE, BLOCK_N, BLOCK_K): + E_mask = E_idxs == E_idx + M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0) + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + K_block = tl.arange(0, BLOCK_K) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we + iters = tl.cdiv(K, BLOCK_K) + for K_block_id in range(0, iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + + if no_n_mask or K_block_id < (iters - 1): + w = tl.load(W_blk_ptrs) + else: + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE) + return E_mask, M_out_idx, N_block, N_mask, acc + + +@triton.autotune( + configs=[ + # different block M and reducing stages + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4), + 
triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4), + # keep 4 stages and keep two 64 block sizes + # - NOTE: these can get good performances for low M, but for large M the variation + # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 128, 'BLOCK_M': 64}, num_stages=4, num_warps=4), + ], + key=["N", "K"], +) +@triton.jit +def groupXtY_triton_kernel( + DY_ptr, + stride_dym, + stride_dyk, + X_ptr, + stride_xm, + stride_xn, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + expert_offsets_ptr, + K: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for i in range(0, iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + + if no_k_mask: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + + if no_n_mask: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) + + DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + + +@triton.autotune(configs=[triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4)], key=["K"]) +@triton.jit +def group_triton_kernel( + src_ptr, + stride_sn, + stride_sk, + has_coeff: tl.constexpr, + coeff_ptr, + FAN_OUT, + tgt_ptr, + stride_tn, + stride_ti, + grouped_idx_ptr, + N, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_block_id = pid + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) + N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) + + K_blk = tl.arange(0, BLOCK_K) + 
src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + no_k_mask = K % BLOCK_K == 0 + + for i in range(0, iters): + if no_k_mask or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=mask) + + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/scattermoe/triton_implementation/ops/__init__.py b/scattermoe/triton_implementation/ops/__init__.py new file mode 100644 index 0000000..861d4e8 --- /dev/null +++ b/scattermoe/triton_implementation/ops/__init__.py @@ -0,0 +1,217 @@ +import torch + +from .compileable_ops import compileable_bincount, group, group_bwd_W, scatter2scatter + + +BLOCK_M = 128 +torch._dynamo.config.capture_scalar_outputs = True + + +def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M): + # there is an overhead of launching a custom op so we only use the custom op when compiling + if torch.compiler.is_compiling(): + expert_counts = compileable_bincount(sorted_experts_idxs, k) + else: + expert_counts = sorted_experts_idxs.bincount(minlength=k) + + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 + padded_expert_block_end = padded_block_counts.cumsum(-1) + expert_boundaries_end = expert_counts.cumsum(-1) + expert_boundaries_start = expert_boundaries_end - expert_counts + padded_expert_block_start = padded_expert_block_end - padded_block_counts + + block_idxs = torch.arange( + padded_expert_block_end[-1], dtype=sorted_experts_idxs.dtype, device=sorted_experts_idxs.device + ).unsqueeze(1) + + block_mask = (block_idxs < padded_expert_block_start) | (block_idxs >= padded_expert_block_end) + expanded_block_idxs = N_BLOCK_SIZE * (block_idxs - padded_expert_block_start) + expert_boundaries_start + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) + + return expanded_block_idxs, expert_boundaries_end + + +class _ScatteredExperts(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + ): + output = torch.empty(sorted_expert_idxs.size(0), expert_weights.size(-1), device=x.device, dtype=x.dtype) + + scatter2scatter( + X=x, + W=expert_weights, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=output, + FAN_OUT=k, + x_grouped=grouped_in, + y_grouped=grouped_out, + ) + + if gates is None: + output_expanded = None + else: + output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) + output = torch.bmm(gates.unsqueeze(1), output_expanded).squeeze(1) + + ctx.save_for_backward( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + output_expanded, + ) + + ctx.grouped_in = grouped_in + ctx.grouped_out = grouped_out + ctx.k = k + + return output + + @staticmethod + def backward(ctx, grad_out): + ( + x, + 
expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + output_expanded, + ) = ctx.saved_tensors + k = ctx.k + grouped_in = ctx.grouped_in + grouped_out = ctx.grouped_out + + if gates is None: + d_gates = None + gates_flat = None + gate_fan = 1 + grouped_grad_out = None + else: + # calculate gates gradient + d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + # print("expanded and grouping") + grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later + + if grouped_out: + grouped_grad_out = grad_out + else: + group( + A=grad_out, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_grad_out, + coeff=gates_flat, + fan_out=gate_fan, + ) + + if grouped_in: + grouped_x = x + d_expanded_input = torch.empty( + sorted_expert_idxs.size(0), expert_weights.size(1), device=x.device, dtype=x.dtype + ) + else: + grouped_x = torch.empty(sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device) + group( + A=x, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_x, + fan_out=k, + ) + + d_expanded_input = grouped_x + + d_weights = torch.zeros( + expert_weights.size(0), + grouped_grad_out.size(-1), + grouped_x.size(-1), + device=grouped_grad_out.device, + dtype=grouped_grad_out.dtype, + ).permute(0, 2, 1) + + group_bwd_W( + DY=grouped_grad_out, + X=grouped_x, + expert_offsets=expert_offsets, + DW=d_weights, + E=expert_weights.size(0), + ) + + scatter2scatter( + X=grouped_grad_out, + W=expert_weights.permute(0, 2, 1), + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=d_expanded_input, + FAN_OUT=1, + x_grouped=True, + y_grouped=grouped_in, + ) + + if k == 1: + d_input = d_expanded_input + else: + d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2) + + # print("backward end.") + return ( + # x, expert_weights, k, + d_input, + d_weights, + None, + # sorted_expert_idxs, sorted_scattered_idxs, + None, + None, + # padded_block_idxs, expert_offsets, + None, + None, + # gates + d_gates, + None, + None, + ) + + +def scattered_experts( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, +): + return _ScatteredExperts.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates, + grouped_in, + grouped_out, + ) diff --git a/scattermoe/triton_implementation/ops/compileable_ops.py b/scattermoe/triton_implementation/ops/compileable_ops.py new file mode 100644 index 0000000..23f547b --- /dev/null +++ b/scattermoe/triton_implementation/ops/compileable_ops.py @@ -0,0 +1,234 @@ +import torch +import triton +import triton.language as tl + +from ..kernels import group_triton_kernel, groupXtY_triton_kernel, scatter2scatter_triton_kernel +from ..utils import torch_custom_op + + +LIBRARY_NAME = "scattermoe" +BLOCK_M = 128 +torch._dynamo.config.capture_scalar_outputs = True + + +# bincount is not compilable +@torch_custom_op(f"{LIBRARY_NAME}::bincount", mutates_args={}) +def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength) + + +@compileable_bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, dtype=torch.long, device=x.device) + 
+ +def _scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * FAN_OUT + assert out.size(0) == sorted_expert_idxs.size(0) + assert out.size(1) == W.size(-1) + + grid = lambda meta: (padded_block_idxs.size(0) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + scatter2scatter_triton_kernel[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + W.stride(0), + W.stride(1), + W.stride(2), + # Y_ptr, stride_ym, stride_yn, + out, + out.stride(0), + out.stride(1), + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + block_start_idx_ptr=padded_block_idxs, + FAN_OUT=FAN_OUT, + M=X.size(0), + K=X.size(1), + N=out.size(1), + E=W.size(0), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::scatter2scatter", mutates_args={"out"}) +def _scatter2scatter_compileable( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + padded_block_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + if torch.compiler.is_compiling(): + _scatter2scatter_compileable( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + else: + _scatter2scatter( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + out=out, + FAN_OUT=FAN_OUT, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def _group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + grid = lambda meta: (E * triton.cdiv(meta["K"], meta["BLOCK_K"]), triton.cdiv(meta["N"], meta["BLOCK_N"])) + + groupXtY_triton_kernel[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, + DY.stride(0), + DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, + X.stride(0), + X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, + DW.stride(0), + DW.stride(1), + DW.stride(2), + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, + N=DY.size(-1), + K=X.size(-1), + # ACC_TYPE: tl.constexpr, + ACC_TYPE=tl.float32, + allow_tf32=torch.backends.cudnn.allow_tf32, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 
+@torch_custom_op(f"{LIBRARY_NAME}::group_bwd_W", mutates_args={"DW"}) +def _group_bwd_W_compileable( + DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int +) -> None: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + if torch.compiler.is_compiling(): + _group_bwd_W_compileable(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + else: + _group_bwd_W(DY=DY, X=X, expert_offsets=expert_offsets, DW=DW, E=E) + + +def _group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + + grid = lambda meta: (triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + group_triton_kernel[grid]( + # A_ptr, stride_an, stride_ai, + A, + A.stride(0), + A.stride(1), + coeff is not None, + coeff, + fan_out, + # Y_ptr, stride_yn, stride_yk, + out, + out.stride(0), + out.stride(1), + # grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, + K, + ) + + +# custom op is needed because of https://github.com/pytorch/pytorch/issues/136394 +@torch_custom_op(f"{LIBRARY_NAME}::group", mutates_args={"out"}) +def _group_compileable( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) + + +def group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + if torch.compiler.is_compiling(): + _group_compileable(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) + else: + _group(A=A, sorted_expert_idxs=sorted_expert_idxs, out=out, coeff=coeff, fan_out=fan_out) diff --git a/scattermoe/triton_implementation/utils.py b/scattermoe/triton_implementation/utils.py new file mode 100644 index 0000000..6270c11 --- /dev/null +++ b/scattermoe/triton_implementation/utils.py @@ -0,0 +1,39 @@ +from typing import Any, Callable, Iterable + +import torch + + +try: + from torch.library import custom_op + + _IS_CUSTOM_OP_IN_PYTORCH = True +except: + _IS_CUSTOM_OP_IN_PYTORCH = False + + +class _IdentityOp: + def __init__(self, fn: Callable) -> None: + self.fn = fn + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.fn(*args, **kwargs) + + def register_fake(self, fn: Callable) -> Callable: + return fn + + +def torch_custom_op( + name: str, + fn: Callable | None = None, + /, + *, + mutates_args: str | Iterable[str], + device_types: torch.device = None, + schema: str | None = None, +) -> Callable | _IdentityOp: + if _IS_CUSTOM_OP_IN_PYTORCH: + op = custom_op(name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema) + else: + op = _IdentityOp if fn is None else _IdentityOp(fn) + + return op \ No newline at end of file diff --git a/tests/test_mlp.py b/tests/test_mlp.py index e781b02..bc7efcd 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -58,8 +58,14 @@ def test_mlp_correctness(self, length, x_dim, h_dim, E, k, dtype): err_dW1 = torch.abs(dW1_ - dW1) err_dW2 = torch.abs(dW2_ - dW2) tolerance = 1e-2 + print("Y error:", err_Y.max()) + print("dg:", err_dg.max()) + print("dW1:", err_dW1.max()) + print("dW2:", err_dW2.max()) + 
print("dX:", err_dX.max()) assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() - assert err_dX.max() < tolerance, "dX error too large: max %0.05f" % err_dX.max() assert err_dg.max() < tolerance, "dg error too large: max %0.05f" % err_dg.max() assert err_dW1.max() < tolerance, "dW1 error too large: max %0.05f" % err_dW1.max() assert err_dW2.max() < tolerance, "dW2 error too large: max %0.05f" % err_dW2.max() + assert err_dX.max() < tolerance, "dX error too large: max %0.05f" % err_dX.max() +