From 3668251bf7f434b214380ad00ea9b5c06fe53ab1 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 20 Sep 2024 18:30:19 -0400 Subject: [PATCH 01/24] merging kernel --- .../sparse_utils/sparse_merge/mege.py | 93 +++++++++ .../sparse_utils/sparse_merge/merge.cu | 191 ++++++++++++++++++ 2 files changed, 284 insertions(+) create mode 100644 mttl/models/modifiers/sparse_utils/sparse_merge/mege.py create mode 100644 mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu diff --git a/mttl/models/modifiers/sparse_utils/sparse_merge/mege.py b/mttl/models/modifiers/sparse_utils/sparse_merge/mege.py new file mode 100644 index 000000000..64d5e70a2 --- /dev/null +++ b/mttl/models/modifiers/sparse_utils/sparse_merge/mege.py @@ -0,0 +1,93 @@ +from pathlib import Path + +import torch +from torch.utils.cpp_extension import load_inline + +WMMA_M = 16 +WMMA_N = 16 +WMMA_K = 16 + +def compile_extension(): + cuda_source = Path("/home/mila/o/ostapeno/dev/mttl_public/mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu").read_text() + cpp_source = "torch::Tensor bmm_w_merge(torch::Tensor A, torch::Tensor B, torch::Tensor W, torch::Tensor SPA);" + + + # Load the CUDA kernel as a PyTorch extension + ext = load_inline( + name="bmm_w_merge_wmma_module", + cpp_sources=cpp_source, + cuda_sources=cuda_source, + functions=["bmm_w_merge"], + with_cuda=True, + extra_cuda_cflags=[ + '-gencode', 'arch=compute_80,code=sm_80', # Adjust for A100 GPU + '--expt-relaxed-constexpr', + '--expt-extended-lambda', "-DTORCH_USE_CUDA_DSA" + ], + build_directory='/home/mila/o/ostapeno/dev/mttl_public/mttl/models/modifiers/sparse_utils/sparse_merge/build', + ) + return ext + +# Load the compiled module +batched_matmul_wmma_merge_module = compile_extension() + +def batched_merge_matmul(A, B, W, Adapters): + assert A.dtype == torch.half and B.dtype == torch.half + assert A.is_cuda and B.is_cuda + + batch_size, M, K = A.shape + _, Kb, N = B.shape + assert K == Kb + res = batched_matmul_wmma_merge_module.bmm_w_merge(A, B, W, Adapters) + return res + +def manual_merge_matmul(A, B, W, Adapters): + merged_adapters = torch.einsum("bk,knm->bnm", W, Adapters) + out = torch.bmm(A, B + merged_adapters) + return out + + +# Set the seed for reproducibility +torch.manual_seed(0) +# Define matrix dimensions +batch_size = 128 +M = 256 +K = 256 +N = 256 +E = 10 + +# Create random half-precision tensors +A = torch.randn(batch_size, M, K, device='cuda', dtype=torch.half).contiguous() +B = torch.randn(batch_size, K, N, device='cuda', dtype=torch.half).contiguous() +#random between 0 and 1 +W = torch.randn(batch_size, E, device='cuda', dtype=torch.half).contiguous() +Adapters = torch.randn(E, K, N, device='cuda', dtype=torch.half).contiguous() + +# batched_matmul_wmma_python(A, B) + +# Warm-up +for _ in range(1): + C_custom = batched_merge_matmul(A, B, W, Adapters) + +# Measure performance +import time + +start = time.time() +C_custom = batched_merge_matmul(A, B, W, Adapters) +torch.cuda.synchronize() +end = time.time() +print(f"Custom kernel time: {end - start:.6f} seconds") + +start = time.time() +C_reference = manual_merge_matmul(A, B, W, Adapters) +torch.cuda.synchronize() +end = time.time() +print(f"PyTorch bmm time: {end - start:.6f} seconds") + +C_custom = C_custom.to(C_reference.dtype) +# Verify correctness +max_error = (C_custom - C_reference).abs().max().item() +print(f"Max error: {max_error}") +# C_reference = C_reference.to(C_custom.dtype) +# C_custom = C_custom.to(C_reference.dtype) +print(torch.allclose(C_custom, C_reference, atol=3e-1)) \ No 
newline at end of file diff --git a/mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu b/mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu new file mode 100644 index 000000000..8eb107ac3 --- /dev/null +++ b/mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu @@ -0,0 +1,191 @@ +#include +#include +#include +#include +#include + +using namespace nvcuda; + +// Constants for WMMA +#define WMMA_M 16 +#define WMMA_N 16 +#define WMMA_K 16 +#define WARP_SIZE 32 + +// Warps per block in M and N dimensions +#define WARPS_PER_BLOCK_M 2 +#define WARPS_PER_BLOCK_N 2 +#define WARPS_PER_BLOCK (WARPS_PER_BLOCK_M * WARPS_PER_BLOCK_N) + +// Assuming maximum E value (number of matrices in M) +#define MAX_E 64 // Adjust based on your maximum E + +// Declare M in constant memory (assuming E is not too large) +// __constant__ __half M_const[MAX_E][WMMA_K][WMMA_N]; // constant memory is like 64KB, so we can keep more than 100 expertparts easily + + +__global__ void batched_matmul_merge_fused_kernel(const __half *__restrict__ A, + const __half *__restrict__ B, + const __half *__restrict__ W, + const __half *__restrict__ SPA, + float *__restrict__ C, + int M, int N, int K, int batch_size, int E) +{ + // Batch index + int batch = blockIdx.z; + // Thread warp index within the thread block + int warp_id = threadIdx.x / WARP_SIZE; // warp withing block + // int lane_id = threadIdx.x % WARP_SIZE; // thread within warp + // blcoksize is 128, so 4 warps per block + int warp_row = warp_id / WARPS_PER_BLOCK_N; // Integer division + int warp_col = warp_id % WARPS_PER_BLOCK_N; + + // Compute the tile indices + // determine the row and column indices of the tile in the output matrix C that a particular warp will compute. + int tile_row = blockIdx.y * WARPS_PER_BLOCK_M + warp_row; + int tile_col = blockIdx.x * WARPS_PER_BLOCK_N + warp_col; + + // Compute the starting row and column indices of the tile + int row = tile_row * WMMA_M; + int col = tile_col * WMMA_N; + + if (row < M && col < N) + { + + // Declare the accumulator fragment + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + // Load per-batch weights into registers + __half W_reg[MAX_E]; + for (int e = 0; e < E; ++e) + { + W_reg[e] = W[batch * E + e]; + } + + // Loop over K dimension + for (int k = 0; k < K; k += WMMA_K) + { + // Declare and initialize fragment for A matrix + wmma::fragment a_frag; + const __half *a_ptr = A + batch * M * K + row * K + k; + + // Declare fragment for B matrix + wmma::fragment B_frag; + + // Compute global memory address for B + const __half *B_ptr = B + batch * K * N + k * N + col; + + // Bounds checking for K dimension + if (k + WMMA_K <= K) + {// Load the A matrix + wmma::load_matrix_sync(a_frag, a_ptr, K); + + // Load the B matrix + wmma::load_matrix_sync(B_frag, B_ptr, N); + + // Compute merged_M_frag for the current K segment + wmma::fragment merged_M_frag; + wmma::fill_fragment(merged_M_frag, __float2half(0.0f)); + + // Accumulate weighted sum into merged_M_frag + for (int e = 0; e < E; ++e) + { + // Load M[e] tile from constant memory into a fragment + wmma::fragment M_frag; + const __half *M_ptr = SPA + (e * K * N) + (k * N + col); + + // Load the M[e] tile + wmma::load_matrix_sync(M_frag, M_ptr, N); + + // Multiply M_frag by W_reg[e] and accumulate + for (int i = 0; i < M_frag.num_elements; ++i) + { + M_frag.x[i] = __hmul(M_frag.x[i], W_reg[e]); + } + + // Accumulate into merged_M_frag + for (int i = 0; i < merged_M_frag.num_elements; ++i) + { + merged_M_frag.x[i] = __hadd(merged_M_frag.x[i], 
M_frag.x[i]); + } + } + + // Add merged_M_frag to B_frag + for (int i = 0; i < B_frag.num_elements; ++i) + { + B_frag.x[i] = __hadd(B_frag.x[i], merged_M_frag.x[i]); + } + + // Perform the matrix multiplication + wmma::mma_sync(c_frag, a_frag, B_frag, c_frag); + } + else + { + assert(false && "K must be a multiple of WMMA_K"); + } + } + + // Store the output + float *c_ptr = C + batch * M * N + row * N + col; + wmma::store_matrix_sync(c_ptr, c_frag, N, wmma::mem_row_major); + } +} + +torch::Tensor bmm_w_merge(torch::Tensor A, torch::Tensor B, torch::Tensor W, torch::Tensor SPA) +{ + const int batch_size = A.size(0); + const int M = A.size(1); + const int K = A.size(2); + const int N = B.size(2); + const int E = SPA.size(0); + + // Ensure the inputs are in half precision + TORCH_CHECK(A.scalar_type() == at::kHalf, "A must be half-precision"); + TORCH_CHECK(B.scalar_type() == at::kHalf, "B must be half-precision"); + TORCH_CHECK(W.scalar_type() == at::kHalf, "W must be half-precision"); + TORCH_CHECK(SPA.scalar_type() == at::kHalf, "SPA must be half-precision"); + + size_t M_size = E * WMMA_K * WMMA_N * sizeof(__half); + TORCH_CHECK(M_size <= 64 * 1024, "M tensor size exceeds constant memory capacity"); + + TORCH_CHECK(K % WMMA_K == 0, "K must be a multiple of WMMA_K (16)"); + // cudaMemcpyToSymbol(M_const, SPA.data_ptr(), M_size); // constant memory is like 64KB, so we can keep limited number of expertperts + + auto C = torch::empty({batch_size, M, N}, torch::TensorOptions().dtype(torch::kFloat32).device(A.device())); + + // idea: + // each warp computes one tile! + // since each block has 4 tiles, each block computes 4 tiles, 2 in each direction + // Calculate the number of tiles + int M_TILES = (M + WMMA_M - 1) / WMMA_M; + int N_TILES = (N + WMMA_N - 1) / WMMA_N; + + // Calculate grid dimensions + int gridDimY = (M_TILES + WARPS_PER_BLOCK_M - 1) / WARPS_PER_BLOCK_M; // so this will be 8 if we have 16 tiles + + int gridDimX = (N_TILES + WARPS_PER_BLOCK_N - 1) / WARPS_PER_BLOCK_N; // this will be also 8 if we have 16 tiles + + dim3 threads(WARP_SIZE * WARPS_PER_BLOCK); + dim3 blocks(gridDimX, gridDimY, batch_size); + + // Launch the CUDA kernel + batched_matmul_merge_fused_kernel<<>>( + reinterpret_cast(A.data_ptr()), + reinterpret_cast(B.data_ptr()), + reinterpret_cast(W.data_ptr()), + reinterpret_cast(SPA.data_ptr()), + C.data_ptr(), + M, N, K, batch_size, E); + + // Check for CUDA errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + std::stringstream ss; + ss << "CUDA kernel failed: " << cudaGetErrorString(err); + throw std::runtime_error(ss.str()); + } + + return C; +} \ No newline at end of file From 766453b26badb9a4c1c6d4547108e26eb5adc724 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 25 Sep 2024 16:00:00 -0400 Subject: [PATCH 02/24] profile_mask_merging --- .../sparse_utils/profile_adapter_merging.py | 679 ++++++++++++++++++ 1 file changed, 679 insertions(+) create mode 100644 mttl/models/modifiers/sparse_utils/profile_adapter_merging.py diff --git a/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py b/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py new file mode 100644 index 000000000..7084f83cf --- /dev/null +++ b/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py @@ -0,0 +1,679 @@ +import logging +import re +import time +from typing import List + +import numpy as np +import pandas as pd +import stk.ops +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton as tn +from pytorch_lightning 
import seed_everything +from spops import csr_add, spmm +from triton.ops.blocksparse import matmul + +from mttl.logging import logger +from mttl.models.modifiers import modify_transformer +from mttl.models.modifiers.base import Modifier +from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig +from mttl.models.modifiers.sparse_mask import ( + MaskedLinear, + ScatteredSparseLinearModule, + SparseLinearModule, + SparseMaskAdapter, + SparseMaskConfig, +) +from mttl.models.utils import model_loader_helper, transfer_batch_to_device + +device = "cuda" +logger.setLevel(logging.ERROR) +block_size = 128 +n_blocks = 16 + +in_d = 2048 +out_d = 8192 +dtype = torch.bfloat16 + +# input sizes and batch sizes for testing +max_seq_len = 1024 +bs = 5 + + +layer = nn.Linear(in_d, out_d).to(device) +layer.weight.requires_grad_(False) +layer.bias.requires_grad_(False) +K = 10 + + +def calculate_lora_parameters(input_dim, output_dim, rank): + return input_dim * rank + output_dim * rank + + +def find_hyperpaams(): + modules = {"linear": layer} + modified_modules = {} + keep_ratios = [] + lora_ranks = [] + + for name, module in modules.items(): + keep_ratio = ( + n_blocks * (block_size**2) / (module.in_features * module.out_features) + ) + tot_sparse_params = module.in_features * module.out_features * keep_ratio + lora_rank = 1 + for rank in range(1, module.in_features): + lora_params = calculate_lora_parameters( + module.in_features, module.out_features, rank + ) + if lora_params <= tot_sparse_params: + lora_rank = rank + else: + break + modified_modules[name] = { + "module": module, + "keep_ratio": keep_ratio, + "lora_rank": lora_rank, + } + keep_ratios.append(keep_ratio) + lora_ranks.append(lora_rank) + return np.mean(keep_ratios), int(np.mean(lora_ranks)) + + +keep_ratio, lora_rank = find_hyperpaams() +print(f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}") +x = torch.randn(bs, max_seq_len, in_d, dtype=dtype, device=device) + + +def create_adapter_set(adapter_config, layer, K) -> List[Modifier]: + if isinstance(adapter_config, SparseMaskConfig): + layer = nn.Linear(out_d, in_d) # TODO: implement transpose in SparseWeights + module = [SparseMaskAdapter(adapter_config, layer) for _ in range(K)] + elif isinstance(adapter_config, LoRAConfig): + module = [LoRA(adapter_config, layer) for _ in range(K)] + return module + + +@torch.autocast(device_type="cuda", dtype=dtype) +def lora_merge(lora_a, lora_b, x, W_base, W_merge): + + # merge into 1 loa + A = torch.einsum("ble,edr->bldr", (W_merge, lora_a)) + B = torch.einsum("ble,erd->blrd", (W_merge, lora_b)) + # lora forward + partial_out = torch.einsum("bld,bldr->blr", (x, A)) + adapter_out = torch.einsum("blr,blrd->bld", (partial_out, B)) + dense_out = x @ W_base + return adapter_out + dense_out + + +@torch.autocast(device_type="cuda", dtype=dtype) +def sparse_merge_and_forward(sparse_weights, x, W_base, W_merge): + """ + Perform the merging of sparse adapters and compute the forward pass. This uses torch dds mm. + + Parameters: + - sparse_weights: List[torch.Tensor], each of shape [input_dim, output_dim] in CSR format. + - x: torch.Tensor, input of shape [bs, max_seq_len, input_dim]. + - W_base: torch.Tensor, base model weights of shape [input_dim, output_dim]. + - W_merge: torch.Tensor, merging weights of shape [bs, max_seq_len, K]. + + Returns: + - y: torch.Tensor, output of shape [bs, max_seq_len, output_dim]. 
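+
+    Note (added for clarity): because W_merge is applied per token, scaling each
+    adapter's output and summing, i.e. sum_k (x @ S_k) * w_k, is algebraically the
+    same as applying a token-wise merged weight x @ (sum_k w_k * S_k); the loop
+    below simply trades that explicit merge for K sparse matmuls.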
+ """ + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + + # Compute the output for this adapter + output_k = ( + x_flat @ S_k + ) # Shape: [bs * max_seq_len, output_dim] <- this is dds mm + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def blck_sparse_merge_and_forward(sparse_weights, x, W_base, W_merge): + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + + # Compute the output for this adapter + output_k = F.linear(x_flat, S_k) # Shape: [bs * max_seq_len, output_dim] + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def sparse_merge_and_forward_with_SpMM(sparse_weights, x, W_base, W_merge): + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + + # Compute the output for this adapter + output_k = spmm( + S_k.sparse_weights, + S_k.row_offs, + S_k.row_idx, + S_k.col_idx, + x_flat.T.contiguous(), + S_k.shape[0], + backend="sputnik", + ) # Shape: [bs * max_seq_len, output_dim] + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k.T * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def sparse_merge_and_forward_with_spadd(sparse_weights, x, W_base, W_merge): + bs, max_seq_len, 
input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_w = torch.zeros(sparse_weights[0].shape).to(device) + + # Iterate over each adapter + for k in range(K): + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + S_k = sparse_weights[k] + + # Compute the output for this adapter + adapter_w = csr_add( + S_k.sparse_weights * W_k, S_k.row_offs, S_k.row_idx, S_k.col_idx, adapter_w + ) + + adapter_out = x_flat @ adapter_w + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def sparse_merge_and_forward_vectorized(sparse_weights, x, W_base, W_merge): + """ + Perform the merging of sparse adapters and compute the forward pass. + + Parameters: + - sparse_weights: List[torch.Tensor], each of shape [input_dim, output_dim] in CSR format. + - x: torch.Tensor, input of shape [bs, max_seq_len, input_dim]. + - W_base: torch.Tensor, base model weights of shape [input_dim, output_dim]. + - W_merge: torch.Tensor, merging weights of shape [bs, max_seq_len, K]. + + Returns: + - y: torch.Tensor, output of shape [bs, max_seq_len, output_dim]. + """ + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Stack and expand sparse weights + # Convert sparse weights to dense (if memory allows) + sparse_weights_dense = torch.stack( + [S.to_dense() for S in sparse_weights], dim=0 + ) # [K, input_dim, output_dim] + + # Compute adapter outputs + # [bs*max_seq_len, K, output_dim] + adapter_out = torch.einsum("bi,kio->bko", x_flat, sparse_weights_dense) + W_merge_flat = W_merge.reshape(bs * max_seq_len, K, 1) # [bs*max_seq_len, K, 1] + adapter_out = (adapter_out * W_merge_flat).sum( + dim=1 + ) # [bs*max_seq_len, output_dim] + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def blk_sparse_merge_and_forward_triton( + block_sparse_ops, block_sparse_weights, x, W_base, W_merge +): + """ + Perform the merging of sparse adapters and compute the forward pass. This uses triton dds kernel with precomputed layour (see prepare_triton_bs_op). + + Parameters: + - block_sparse_ops: List[triton.ops.blocksparse.matmul], each of shape [input_dim, output_dim] in CSR format. + - block_sparse_weights: List[torch.Tensor], each of shape [input_dim, output_dim] in BSR format (these are only non-zero blocks). + - x: torch.Tensor, input of shape [bs, max_seq_len, input_dim]. + - W_base: torch.Tensor, base model weights of shape [input_dim, output_dim]. + - W_merge: torch.Tensor, merging weights of shape [bs, max_seq_len, K]. + + Returns: + - y: torch.Tensor, output of shape [bs, max_seq_len, output_dim]. 
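+
+    Note (added for clarity): in the "dds" mode used here the first operand (the
+    activations) is dense, the second operand (the adapter weights) is block-sparse,
+    and the output is dense; prepare_triton_bs_op precomputes the block layout once
+    per adapter, so only the non-zero block values need to be passed at call time.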
+ """ + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = block_sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + _x_flat = x_flat.unsqueeze(0).unsqueeze(0).contiguous() + # Compute the output for this adapter + output_k = block_sparse_ops[k]( + _x_flat, S_k + ).squeeze() # Shape: [bs * max_seq_len, output_dim] + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def blk_sparse_merge_and_forward_stk(block_sparse_weights, x, W_base, W_merge): + """ + Perform the merging of sparse adapters and compute the forward pass. This uses triton dds kernel with precomputed layour (see prepare_triton_bs_op). + + Parameters: + - block_sparse_ops: List[triton.ops.blocksparse.matmul], each of shape [input_dim, output_dim] in CSR format. + - block_sparse_weights: List[torch.Tensor], each of shape [input_dim, output_dim] in BSR format (these are only non-zero blocks). + - x: torch.Tensor, input of shape [bs, max_seq_len, input_dim]. + - W_base: torch.Tensor, base model weights of shape [input_dim, output_dim]. + - W_merge: torch.Tensor, merging weights of shape [bs, max_seq_len, K]. + + Returns: + - y: torch.Tensor, output of shape [bs, max_seq_len, output_dim]. 
+ """ + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = block_sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + # Compute the output for this adapter + output_k = stk.ops.dds(x_flat, S_k) # Shape: [bs * max_seq_len, output_dim] + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +sparse_merge_and_forward_compiled = torch.compile(sparse_merge_and_forward) +lora_merge_compiled = torch.compile(lora_merge) + +# adapter_config_sm = SparseMaskConfig( +# sps_impl="scattered", +# sps_type="regular_sparse", +# keep_ratio=keep_ratio, +# reselection_steps=1, +# block_size=block_size, +# ) + + +adapter_config_lora = LoRAConfig(modify_layers="", lora_rank=lora_rank) +adapter_config_bs = SparseMaskConfig( + sps_impl="scattered", + sps_type="block_sparse", + keep_ratio=keep_ratio, + reselection_steps=1, + block_size=block_size, +) + + +def bsr_to_binary_layout(bsr_matrix, block_size): + # Get the shape of the BSR matrix + M, K = bsr_matrix.shape + + # Number of blocks along rows and columns + num_block_rows = M // block_size + num_block_cols = K // block_size + + # Initialize the binary layout matrix with zeros + binary_layout = torch.zeros((num_block_rows, num_block_cols), dtype=int) + + # Get BSR matrix data + block_row_indices = bsr_matrix.col_indices() + block_row_pointers = bsr_matrix.crow_indices() + + # Iterate over the block rows + for block_row in range(num_block_rows): + # Iterate over the non-zero blocks in the current block row + for idx in range( + block_row_pointers[block_row], block_row_pointers[block_row + 1] + ): + block_col = block_row_indices[idx] + # Mark the block as non-zero + binary_layout[block_row, block_col] = 1 + + return binary_layout + + +def prepare_triton_bs_op(W, op_mode): + Z, H = 1, 1 + AT = False + BT = False + + layout = bsr_to_binary_layout(W, block_size).unsqueeze(0) + # creat inputs + op = matmul(layout, block_size, op_mode, trans_a=AT, trans_b=BT, device="cuda") + return op + + +@tn.testing.perf_report( + tn.testing.Benchmark( + x_names=["K"], # Argument names to use as an x-axis for the plot. + x_vals=[2, 3, 4], # Different possible values for `x_name`. + x_log=False, # x axis is logarithmic. + line_arg="provider", # Argument name whose value corresponds to a different line in the plot. + line_vals=[ + "stk", + "triton_blck_sparse", + "lora", + "torch_sparse", + "torch_block_sparse", + ], # "lora_compiled", "torch_sparse_compiled"], # Possible values for `line_arg`. + line_names=[ + "stk", + "triton_blck_sparse", + "lora", + "torch_sparse", + "torch_block_sparse", + ], # "lora_compiled", "torch_sparse_compiled"], # Label name for the lines. + styles=[ + ("blue", "-"), + ("green", "-"), + ("orange", "-"), + ("red", "-"), + ("purple", "-"), + ("black", "-"), + ("brown", "-"), + ], # Line color and style. 
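+        # Note (added for clarity): the sweep varies the number of merged adapters
+        # K; every provider below is timed with tn.testing.do_bench at the
+        # 0.5/0.2/0.8 quantiles, so each reported curve is the median latency in ms
+        # with a band given by the 20th and 80th percentiles.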
+ ylabel="ms", #'GB/s', # Label name for the y-axis. + xlabel="K", + plot_name="matmul-performance", # Name for the plot. Used also as a file name for saving the plot. + args={"bs": bs, "max_seq_len": max_seq_len, "in_d": in_d, "d_out": out_d}, + ) +) +def benchmark(K, bs, max_seq_len, in_d, d_out, provider): + W_mege = torch.randn(bs, max_seq_len, K, dtype=dtype, device=device) + loras = create_adapter_set(adapter_config_lora, layer, K) + sparse_modules = create_adapter_set(adapter_config_bs, layer, K) + W_mege = W_mege.to(dtype=loras[0].lora_a.dtype) + + lora_a = torch.stack([lora.lora_a for lora in loras], dim=0) + lora_b = torch.stack([lora.lora_b for lora in loras], dim=0) + sparse_weights: List[torch.Tensor] = [ + sparse_module.sparse_layer.to_dense().to_sparse_csr().to(device) + for sparse_module in sparse_modules + ] + sparse_weights_spops = [ + sparse_module.sparse_layer.to(device) for sparse_module in sparse_modules + ] + + print("Testing provider:", provider, "K:", K) + quantiles = [0.5, 0.2, 0.8] + if provider == "lora": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: lora_merge(lora_a, lora_b, x, layer.weight.T, W_mege), + quantiles=quantiles, + ) + elif provider == "lora_compiled": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: lora_merge_compiled(lora_a, lora_b, x, layer.weight.T, W_mege), + quantiles=quantiles, + ) + elif provider == "torch_sparse": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: sparse_merge_and_forward(sparse_weights, x, layer.weight.T, W_mege), + quantiles=quantiles, + ) + elif provider == "torch_block_sparse": + block_sparse_weights: List[torch.Tensor] = [ + sparse_module.sparse_layer.to_dense().T.to_sparse_bsr(block_size).to(device) + for sparse_module in sparse_modules + ] + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: blck_sparse_merge_and_forward( + block_sparse_weights, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + elif provider == "torch_sparse_compiled": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: sparse_merge_and_forward_compiled( + sparse_weights, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + elif provider == "sparse_vectorized": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: sparse_merge_and_forward_vectorized( + sparse_weights, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + elif provider == "sparse_spadd": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: sparse_merge_and_forward_with_spadd( + sparse_weights_spops, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + elif provider == "triton_blck_sparse": + block_sparse_weights: List[torch.Tensor] = [ + sparse_module.sparse_layer.to_dense().to_sparse_bsr(block_size).to(device) + for sparse_module in sparse_modules + ] + # create a list of ops with precomputed layouts for the BSR matrices + block_sparse_ops = [ + prepare_triton_bs_op(sparse_w, "dds") for sparse_w in block_sparse_weights + ] + # block_sparse_weights_as_dense = [ + # sparse_w.to_dense() + # .to(dtype) + # .reshape(-1, block_size, block_size) + # .unsqueeze(0) + # .contiguous() + # for sparse_w in block_sparse_weights + # ] + block_sparse_weights = [ + sparse_w.values().to(dtype).unsqueeze(0).contiguous() + for sparse_w in block_sparse_weights + ] + + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: blk_sparse_merge_and_forward_triton( + block_sparse_ops, + block_sparse_weights, + x, + layer.weight.T, + W_mege, + ), + quantiles=quantiles, + ) + + elif provider == "stk": + # only supports block_size = 128 and float16 + if block_size 
!= 128: + ms, min_ms, max_ms = 0, 0, 0 + else: + block_sparse_weights = [] + for sparse_module in sparse_modules: + W = sparse_module.sparse_layer.to_dense().to(device).to(torch.float16) + W_stk = stk.ops.to_sparse(W, blocking=block_size) + W_stk.validate() + block_sparse_weights.append(W_stk) + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: blk_sparse_merge_and_forward_stk( + block_sparse_weights, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + + # gbps = lambda ms: 2 * s * h * o * 2 * 1e-9 / (ms * 1e-3) + # return gbps(ms), gbps(max_ms), gbps(min_ms) + return ms, max_ms, min_ms + + +benchmark.run(show_plots=True, print_data=True, save_path=".") From 2dd98aba523640ba81df0e6ddc4230fa61a298ff Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 25 Sep 2024 16:15:23 -0400 Subject: [PATCH 03/24] profile adapter merging --- .../sparse_utils/profile_adapter_merging.py | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py b/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py index 7084f83cf..d7bd3d3d1 100644 --- a/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py +++ b/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py @@ -29,8 +29,8 @@ device = "cuda" logger.setLevel(logging.ERROR) -block_size = 128 -n_blocks = 16 +block_size = 128 # 16 +n_blocks = 16 # 1024 in_d = 2048 out_d = 8192 @@ -82,7 +82,9 @@ def find_hyperpaams(): keep_ratio, lora_rank = find_hyperpaams() -print(f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}") +print( + f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}, Lora params: {calculate_lora_parameters(in_d, out_d, lora_rank)}, Sparse params: {in_d * out_d * keep_ratio}" +) x = torch.randn(bs, max_seq_len, in_d, dtype=dtype, device=device) @@ -407,19 +409,6 @@ def blk_sparse_merge_and_forward_triton( @torch.autocast(device_type="cuda", dtype=dtype) def blk_sparse_merge_and_forward_stk(block_sparse_weights, x, W_base, W_merge): - """ - Perform the merging of sparse adapters and compute the forward pass. This uses triton dds kernel with precomputed layour (see prepare_triton_bs_op). - - Parameters: - - block_sparse_ops: List[triton.ops.blocksparse.matmul], each of shape [input_dim, output_dim] in CSR format. - - block_sparse_weights: List[torch.Tensor], each of shape [input_dim, output_dim] in BSR format (these are only non-zero blocks). - - x: torch.Tensor, input of shape [bs, max_seq_len, input_dim]. - - W_base: torch.Tensor, base model weights of shape [input_dim, output_dim]. - - W_merge: torch.Tensor, merging weights of shape [bs, max_seq_len, K]. - - Returns: - - y: torch.Tensor, output of shape [bs, max_seq_len, output_dim]. 
- """ bs, max_seq_len, input_dim = x.shape output_dim = W_base.shape[1] K = W_merge.shape[2] From dad6f1922b34de71bd5c729a58804f8db7d40a99 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 27 Sep 2024 09:45:31 -0400 Subject: [PATCH 04/24] sparse merging and sdd kernel --- mttl/models/modifiers/__init__.py | 2 +- .../sparse_utils/sparse_merge/mege.py | 93 ----- .../sparse_utils/sparse_merge/merge.cu | 191 --------- mttl/models/modifiers/spasity/__init__.py | 3 + mttl/models/modifiers/spasity/matrix.py | 364 ++++++++++++++++++ .../modifiers/{ => spasity}/sparse_mask.py | 27 +- .../sparse_utils/csr_add_vs_scatter_add.py | 2 +- .../sparse_utils/profile_adapter_merging.py | 8 +- .../sparse_utils/profile_block_sparsity.py | 6 +- .../sparse_utils/profile_sparse_mask.py | 2 +- .../profile_sparse_mask_only_linear.py | 2 +- .../{ => spasity}/sparse_utils/utils.py | 4 +- mttl/models/modifiers/spasity/stk/__init__.py | 1 + mttl/models/modifiers/spasity/stk/autocast.py | 40 ++ .../models/modifiers/spasity/stk/functions.py | 54 +++ .../modifiers/spasity/stk/linear_ops.py | 45 +++ .../modifiers/spasity/stk/linear_ops_test.py | 146 +++++++ .../modifiers/spasity/stk/matrix_ops.py | 195 ++++++++++ .../modifiers/spasity/stk/matrix_ops_test.py | 64 +++ .../modifiers/spasity/stk/random_ops.py | 37 ++ .../modifiers/spasity/stk/triton_kernels.py | 197 ++++++++++ tests/test_sparse_masks.py | 2 +- 22 files changed, 1185 insertions(+), 300 deletions(-) delete mode 100644 mttl/models/modifiers/sparse_utils/sparse_merge/mege.py delete mode 100644 mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu create mode 100644 mttl/models/modifiers/spasity/__init__.py create mode 100644 mttl/models/modifiers/spasity/matrix.py rename mttl/models/modifiers/{ => spasity}/sparse_mask.py (97%) rename mttl/models/modifiers/{ => spasity}/sparse_utils/csr_add_vs_scatter_add.py (97%) rename mttl/models/modifiers/{ => spasity}/sparse_utils/profile_adapter_merging.py (99%) rename mttl/models/modifiers/{ => spasity}/sparse_utils/profile_block_sparsity.py (96%) rename mttl/models/modifiers/{ => spasity}/sparse_utils/profile_sparse_mask.py (99%) rename mttl/models/modifiers/{ => spasity}/sparse_utils/profile_sparse_mask_only_linear.py (99%) rename mttl/models/modifiers/{ => spasity}/sparse_utils/utils.py (99%) create mode 100644 mttl/models/modifiers/spasity/stk/__init__.py create mode 100644 mttl/models/modifiers/spasity/stk/autocast.py create mode 100644 mttl/models/modifiers/spasity/stk/functions.py create mode 100644 mttl/models/modifiers/spasity/stk/linear_ops.py create mode 100644 mttl/models/modifiers/spasity/stk/linear_ops_test.py create mode 100644 mttl/models/modifiers/spasity/stk/matrix_ops.py create mode 100644 mttl/models/modifiers/spasity/stk/matrix_ops_test.py create mode 100644 mttl/models/modifiers/spasity/stk/random_ops.py create mode 100644 mttl/models/modifiers/spasity/stk/triton_kernels.py diff --git a/mttl/models/modifiers/__init__.py b/mttl/models/modifiers/__init__.py index 49733028c..63de57288 100644 --- a/mttl/models/modifiers/__init__.py +++ b/mttl/models/modifiers/__init__.py @@ -6,4 +6,4 @@ import mttl.models.modifiers.lora # noqa: F401 import mttl.models.modifiers.mlp # noqa: F401 import mttl.models.modifiers.prompt_tuning # noqa: F401 -import mttl.models.modifiers.sparse_mask # noqa: F401 +import mttl.models.modifiers.spasity.sparse_mask # noqa: F401 diff --git a/mttl/models/modifiers/sparse_utils/sparse_merge/mege.py b/mttl/models/modifiers/sparse_utils/sparse_merge/mege.py deleted file mode 100644 index 
64d5e70a2..000000000 --- a/mttl/models/modifiers/sparse_utils/sparse_merge/mege.py +++ /dev/null @@ -1,93 +0,0 @@ -from pathlib import Path - -import torch -from torch.utils.cpp_extension import load_inline - -WMMA_M = 16 -WMMA_N = 16 -WMMA_K = 16 - -def compile_extension(): - cuda_source = Path("/home/mila/o/ostapeno/dev/mttl_public/mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu").read_text() - cpp_source = "torch::Tensor bmm_w_merge(torch::Tensor A, torch::Tensor B, torch::Tensor W, torch::Tensor SPA);" - - - # Load the CUDA kernel as a PyTorch extension - ext = load_inline( - name="bmm_w_merge_wmma_module", - cpp_sources=cpp_source, - cuda_sources=cuda_source, - functions=["bmm_w_merge"], - with_cuda=True, - extra_cuda_cflags=[ - '-gencode', 'arch=compute_80,code=sm_80', # Adjust for A100 GPU - '--expt-relaxed-constexpr', - '--expt-extended-lambda', "-DTORCH_USE_CUDA_DSA" - ], - build_directory='/home/mila/o/ostapeno/dev/mttl_public/mttl/models/modifiers/sparse_utils/sparse_merge/build', - ) - return ext - -# Load the compiled module -batched_matmul_wmma_merge_module = compile_extension() - -def batched_merge_matmul(A, B, W, Adapters): - assert A.dtype == torch.half and B.dtype == torch.half - assert A.is_cuda and B.is_cuda - - batch_size, M, K = A.shape - _, Kb, N = B.shape - assert K == Kb - res = batched_matmul_wmma_merge_module.bmm_w_merge(A, B, W, Adapters) - return res - -def manual_merge_matmul(A, B, W, Adapters): - merged_adapters = torch.einsum("bk,knm->bnm", W, Adapters) - out = torch.bmm(A, B + merged_adapters) - return out - - -# Set the seed for reproducibility -torch.manual_seed(0) -# Define matrix dimensions -batch_size = 128 -M = 256 -K = 256 -N = 256 -E = 10 - -# Create random half-precision tensors -A = torch.randn(batch_size, M, K, device='cuda', dtype=torch.half).contiguous() -B = torch.randn(batch_size, K, N, device='cuda', dtype=torch.half).contiguous() -#random between 0 and 1 -W = torch.randn(batch_size, E, device='cuda', dtype=torch.half).contiguous() -Adapters = torch.randn(E, K, N, device='cuda', dtype=torch.half).contiguous() - -# batched_matmul_wmma_python(A, B) - -# Warm-up -for _ in range(1): - C_custom = batched_merge_matmul(A, B, W, Adapters) - -# Measure performance -import time - -start = time.time() -C_custom = batched_merge_matmul(A, B, W, Adapters) -torch.cuda.synchronize() -end = time.time() -print(f"Custom kernel time: {end - start:.6f} seconds") - -start = time.time() -C_reference = manual_merge_matmul(A, B, W, Adapters) -torch.cuda.synchronize() -end = time.time() -print(f"PyTorch bmm time: {end - start:.6f} seconds") - -C_custom = C_custom.to(C_reference.dtype) -# Verify correctness -max_error = (C_custom - C_reference).abs().max().item() -print(f"Max error: {max_error}") -# C_reference = C_reference.to(C_custom.dtype) -# C_custom = C_custom.to(C_reference.dtype) -print(torch.allclose(C_custom, C_reference, atol=3e-1)) \ No newline at end of file diff --git a/mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu b/mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu deleted file mode 100644 index 8eb107ac3..000000000 --- a/mttl/models/modifiers/sparse_utils/sparse_merge/merge.cu +++ /dev/null @@ -1,191 +0,0 @@ -#include -#include -#include -#include -#include - -using namespace nvcuda; - -// Constants for WMMA -#define WMMA_M 16 -#define WMMA_N 16 -#define WMMA_K 16 -#define WARP_SIZE 32 - -// Warps per block in M and N dimensions -#define WARPS_PER_BLOCK_M 2 -#define WARPS_PER_BLOCK_N 2 -#define WARPS_PER_BLOCK 
(WARPS_PER_BLOCK_M * WARPS_PER_BLOCK_N) - -// Assuming maximum E value (number of matrices in M) -#define MAX_E 64 // Adjust based on your maximum E - -// Declare M in constant memory (assuming E is not too large) -// __constant__ __half M_const[MAX_E][WMMA_K][WMMA_N]; // constant memory is like 64KB, so we can keep more than 100 expertparts easily - - -__global__ void batched_matmul_merge_fused_kernel(const __half *__restrict__ A, - const __half *__restrict__ B, - const __half *__restrict__ W, - const __half *__restrict__ SPA, - float *__restrict__ C, - int M, int N, int K, int batch_size, int E) -{ - // Batch index - int batch = blockIdx.z; - // Thread warp index within the thread block - int warp_id = threadIdx.x / WARP_SIZE; // warp withing block - // int lane_id = threadIdx.x % WARP_SIZE; // thread within warp - // blcoksize is 128, so 4 warps per block - int warp_row = warp_id / WARPS_PER_BLOCK_N; // Integer division - int warp_col = warp_id % WARPS_PER_BLOCK_N; - - // Compute the tile indices - // determine the row and column indices of the tile in the output matrix C that a particular warp will compute. - int tile_row = blockIdx.y * WARPS_PER_BLOCK_M + warp_row; - int tile_col = blockIdx.x * WARPS_PER_BLOCK_N + warp_col; - - // Compute the starting row and column indices of the tile - int row = tile_row * WMMA_M; - int col = tile_col * WMMA_N; - - if (row < M && col < N) - { - - // Declare the accumulator fragment - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); - - // Load per-batch weights into registers - __half W_reg[MAX_E]; - for (int e = 0; e < E; ++e) - { - W_reg[e] = W[batch * E + e]; - } - - // Loop over K dimension - for (int k = 0; k < K; k += WMMA_K) - { - // Declare and initialize fragment for A matrix - wmma::fragment a_frag; - const __half *a_ptr = A + batch * M * K + row * K + k; - - // Declare fragment for B matrix - wmma::fragment B_frag; - - // Compute global memory address for B - const __half *B_ptr = B + batch * K * N + k * N + col; - - // Bounds checking for K dimension - if (k + WMMA_K <= K) - {// Load the A matrix - wmma::load_matrix_sync(a_frag, a_ptr, K); - - // Load the B matrix - wmma::load_matrix_sync(B_frag, B_ptr, N); - - // Compute merged_M_frag for the current K segment - wmma::fragment merged_M_frag; - wmma::fill_fragment(merged_M_frag, __float2half(0.0f)); - - // Accumulate weighted sum into merged_M_frag - for (int e = 0; e < E; ++e) - { - // Load M[e] tile from constant memory into a fragment - wmma::fragment M_frag; - const __half *M_ptr = SPA + (e * K * N) + (k * N + col); - - // Load the M[e] tile - wmma::load_matrix_sync(M_frag, M_ptr, N); - - // Multiply M_frag by W_reg[e] and accumulate - for (int i = 0; i < M_frag.num_elements; ++i) - { - M_frag.x[i] = __hmul(M_frag.x[i], W_reg[e]); - } - - // Accumulate into merged_M_frag - for (int i = 0; i < merged_M_frag.num_elements; ++i) - { - merged_M_frag.x[i] = __hadd(merged_M_frag.x[i], M_frag.x[i]); - } - } - - // Add merged_M_frag to B_frag - for (int i = 0; i < B_frag.num_elements; ++i) - { - B_frag.x[i] = __hadd(B_frag.x[i], merged_M_frag.x[i]); - } - - // Perform the matrix multiplication - wmma::mma_sync(c_frag, a_frag, B_frag, c_frag); - } - else - { - assert(false && "K must be a multiple of WMMA_K"); - } - } - - // Store the output - float *c_ptr = C + batch * M * N + row * N + col; - wmma::store_matrix_sync(c_ptr, c_frag, N, wmma::mem_row_major); - } -} - -torch::Tensor bmm_w_merge(torch::Tensor A, torch::Tensor B, torch::Tensor W, torch::Tensor SPA) -{ - const int 
batch_size = A.size(0); - const int M = A.size(1); - const int K = A.size(2); - const int N = B.size(2); - const int E = SPA.size(0); - - // Ensure the inputs are in half precision - TORCH_CHECK(A.scalar_type() == at::kHalf, "A must be half-precision"); - TORCH_CHECK(B.scalar_type() == at::kHalf, "B must be half-precision"); - TORCH_CHECK(W.scalar_type() == at::kHalf, "W must be half-precision"); - TORCH_CHECK(SPA.scalar_type() == at::kHalf, "SPA must be half-precision"); - - size_t M_size = E * WMMA_K * WMMA_N * sizeof(__half); - TORCH_CHECK(M_size <= 64 * 1024, "M tensor size exceeds constant memory capacity"); - - TORCH_CHECK(K % WMMA_K == 0, "K must be a multiple of WMMA_K (16)"); - // cudaMemcpyToSymbol(M_const, SPA.data_ptr(), M_size); // constant memory is like 64KB, so we can keep limited number of expertperts - - auto C = torch::empty({batch_size, M, N}, torch::TensorOptions().dtype(torch::kFloat32).device(A.device())); - - // idea: - // each warp computes one tile! - // since each block has 4 tiles, each block computes 4 tiles, 2 in each direction - // Calculate the number of tiles - int M_TILES = (M + WMMA_M - 1) / WMMA_M; - int N_TILES = (N + WMMA_N - 1) / WMMA_N; - - // Calculate grid dimensions - int gridDimY = (M_TILES + WARPS_PER_BLOCK_M - 1) / WARPS_PER_BLOCK_M; // so this will be 8 if we have 16 tiles - - int gridDimX = (N_TILES + WARPS_PER_BLOCK_N - 1) / WARPS_PER_BLOCK_N; // this will be also 8 if we have 16 tiles - - dim3 threads(WARP_SIZE * WARPS_PER_BLOCK); - dim3 blocks(gridDimX, gridDimY, batch_size); - - // Launch the CUDA kernel - batched_matmul_merge_fused_kernel<<>>( - reinterpret_cast(A.data_ptr()), - reinterpret_cast(B.data_ptr()), - reinterpret_cast(W.data_ptr()), - reinterpret_cast(SPA.data_ptr()), - C.data_ptr(), - M, N, K, batch_size, E); - - // Check for CUDA errors - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - std::stringstream ss; - ss << "CUDA kernel failed: " << cudaGetErrorString(err); - throw std::runtime_error(ss.str()); - } - - return C; -} \ No newline at end of file diff --git a/mttl/models/modifiers/spasity/__init__.py b/mttl/models/modifiers/spasity/__init__.py new file mode 100644 index 000000000..67fb01f06 --- /dev/null +++ b/mttl/models/modifiers/spasity/__init__.py @@ -0,0 +1,3 @@ +# largely inspired/adopted from STK: https://github.com/stanford-futuredata/stk +from mttl.models.modifiers.spasity.matrix import Matrix +from mttl.models.modifiers.spasity.sparse_mask import * diff --git a/mttl/models/modifiers/spasity/matrix.py b/mttl/models/modifiers/spasity/matrix.py new file mode 100644 index 000000000..dc6fe8be5 --- /dev/null +++ b/mttl/models/modifiers/spasity/matrix.py @@ -0,0 +1,364 @@ +import numpy as np +import torch + +# this is copy paste from stk: https://github.com/stanford-futuredata/stk + + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}" + ) + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. 
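+    # Illustrative example of the expected metadata shapes: a 256x256 matrix with
+    # 128x128 blocking and 2 non-zero blocks is described by
+    #   data           -> [2, 128, 128] (float16 block values)
+    #   row_indices    -> [2]           (int16 block-row of each block)
+    #   column_indices -> [2]           (int16 block-column of each block)
+    #   offsets        -> [3]           (int32, one entry per block row plus one)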
+ block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). " + f"Got shape {data.dim()}D shape." + ) + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking." + ) + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})" + ) + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices." + ) + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices." + ) + + if offsets.dim() != 1: + raise ValueError(f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks" + ) + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks" + ) + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows." + ) + + is_cuda = ( + data.is_cuda + and row_indices.is_cuda + and column_indices.is_cuda + and offsets.is_cuda + ) + is_cpu = ( + not data.is_cuda + and not row_indices.is_cuda + and not column_indices.is_cuda + and not offsets.is_cuda + ) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}." + ) + + if data.dtype != torch.float16: + raise ValueError(f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices." + ) + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. Got {column_indices.dtype} column_indices." + ) + if offsets.dtype != torch.int32: + raise ValueError(f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. 
+ + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. + """ + + def __init__( + self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None, + ): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ( + (column_indices_t is None) + or (offsets_t is None) + or (block_offsets_t is None) + ): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets + ) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices." + ) + + def validate(self): + _validate_matrix( + self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets, + ) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. + self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone(), + ) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D." 
+ ) + out = Matrix( + self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t, + ) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}." + ) + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}" + ) + return Matrix( + shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t, + ) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. 
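+        # Note (added for clarity): for a transposed (non-contiguous) view the
+        # gradient Matrix is first built over the untransposed storage size and
+        # then logically transposed again below, mirroring how the data itself is
+        # laid out.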
+ size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix( + size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t, + ) + return out if self.is_contiguous() else out.t() diff --git a/mttl/models/modifiers/sparse_mask.py b/mttl/models/modifiers/spasity/sparse_mask.py similarity index 97% rename from mttl/models/modifiers/sparse_mask.py rename to mttl/models/modifiers/spasity/sparse_mask.py index dd9643a48..53ba15ce5 100644 --- a/mttl/models/modifiers/sparse_mask.py +++ b/mttl/models/modifiers/spasity/sparse_mask.py @@ -4,13 +4,13 @@ import numpy as np import torch -from scipy.sparse import csr_matrix +from scipy.sparse import bsr_matrix, csr_matrix from torch import nn from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut from mttl.logging import logger from mttl.models.modifiers.base import Modifier, ModifierConfig, ModifyMixin -from mttl.models.modifiers.sparse_utils.utils import ( +from mttl.models.modifiers.spasity.sparse_utils.utils import ( BlcokSparseLinearFunction_SP_ADD, BlcokSparseLinearFunction_SP_SCATTER, LinearWithSparseDelta, @@ -135,6 +135,12 @@ def __init__(self, config: SparseMaskConfig, shape, dtype, device, **kwargs): self.set_sparse_idxs(_sparse_csr_representation) + self.init_random() + + def init_random(self): + # init 1D tensor of values + nn.init.normal_(self.sparse_weights, mean=0.0, std=0.1) + @property def device(self): return self.sparse_weights.device @@ -197,7 +203,7 @@ def twoD_indices(self): """ return get_2d_indices_from_csr_matrix(self.scipy_representation) - def to_dense(self): + def to_dense(self) -> torch.Tensor: """ Returns dense representation of the sparse weights. """ @@ -227,10 +233,25 @@ class BlockSparseWeights(SparseWeights): def __init__(self, config: SparseMaskConfig, shape, dtype, device, **kwargs): super().__init__(config, shape, dtype, device, **kwargs) + self.sparse_weights = nn.Parameter( self.sparse_weights.data.view(-1, self.block_size, self.block_size), requires_grad=True, ) + bsr_indices, bsr_indptr = self.get_to_bsr_indices() + self.register_buffer("bsr_indices", bsr_indices) + self.register_buffer("bsr_indptr", bsr_indptr) + + def get_to_bsr_indices(self) -> bsr_matrix: + """ + Returns the sparse weights in BSR format. 
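+
+        Concretely, a scipy bsr_matrix is built from the CSR representation and its
+        block column indices (indices) and block row pointers (indptr) are returned
+        as int32 tensors for use by block-sparse kernels.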
+ """ + bsr = bsr_matrix( + self.scipy_representation, blocksize=(self.block_size, self.block_size) + ) + return torch.tensor(bsr.indices, dtype=torch.int32), torch.tensor( + bsr.indptr, dtype=torch.int32 + ) @property def scipy_representation(self): diff --git a/mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py b/mttl/models/modifiers/spasity/sparse_utils/csr_add_vs_scatter_add.py similarity index 97% rename from mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py rename to mttl/models/modifiers/spasity/sparse_utils/csr_add_vs_scatter_add.py index 1f6888096..1a104404e 100644 --- a/mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py +++ b/mttl/models/modifiers/spasity/sparse_utils/csr_add_vs_scatter_add.py @@ -7,8 +7,8 @@ from spops import csr_add, sddmm from triton.ops.blocksparse import matmul -from mttl.models.modifiers.sparse_mask import SparseMaskConfig, SparseWeights from mttl.models.modifiers.sparse_utils.utils import init_sparse_weights +from mttl.models.modifiers.spasity.sparse_mask import SparseMaskConfig, SparseWeights n_blocks = 8 BLOCK_SIZE = 128 diff --git a/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py b/mttl/models/modifiers/spasity/sparse_utils/profile_adapter_merging.py similarity index 99% rename from mttl/models/modifiers/sparse_utils/profile_adapter_merging.py rename to mttl/models/modifiers/spasity/sparse_utils/profile_adapter_merging.py index d7bd3d3d1..1da9d168e 100644 --- a/mttl/models/modifiers/sparse_utils/profile_adapter_merging.py +++ b/mttl/models/modifiers/spasity/sparse_utils/profile_adapter_merging.py @@ -18,7 +18,7 @@ from mttl.models.modifiers import modify_transformer from mttl.models.modifiers.base import Modifier from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig -from mttl.models.modifiers.sparse_mask import ( +from mttl.models.modifiers.spasity.sparse_mask import ( MaskedLinear, ScatteredSparseLinearModule, SparseLinearModule, @@ -29,8 +29,8 @@ device = "cuda" logger.setLevel(logging.ERROR) -block_size = 128 # 16 -n_blocks = 16 # 1024 +block_size = 16 # 128 # 16 +n_blocks = 1024 # 16 # 1024 in_d = 2048 out_d = 8192 @@ -512,7 +512,7 @@ def prepare_triton_bs_op(W, op_mode): @tn.testing.perf_report( tn.testing.Benchmark( x_names=["K"], # Argument names to use as an x-axis for the plot. - x_vals=[2, 3, 4], # Different possible values for `x_name`. + x_vals=[2, 3, 4, 10, 64, 128], # Different possible values for `x_name`. x_log=False, # x axis is logarithmic. line_arg="provider", # Argument name whose value corresponds to a different line in the plot. 
line_vals=[ diff --git a/mttl/models/modifiers/sparse_utils/profile_block_sparsity.py b/mttl/models/modifiers/spasity/sparse_utils/profile_block_sparsity.py similarity index 96% rename from mttl/models/modifiers/sparse_utils/profile_block_sparsity.py rename to mttl/models/modifiers/spasity/sparse_utils/profile_block_sparsity.py index 58a08fc6b..446f8ed44 100644 --- a/mttl/models/modifiers/sparse_utils/profile_block_sparsity.py +++ b/mttl/models/modifiers/spasity/sparse_utils/profile_block_sparsity.py @@ -9,8 +9,8 @@ from spops import csr_add, sddmm from triton.ops.blocksparse import matmul -from mttl.models.modifiers.sparse_mask import SparseMaskConfig, SparseWeights from mttl.models.modifiers.sparse_utils.utils import init_sparse_weights +from mttl.models.modifiers.spasity.sparse_mask import SparseMaskConfig, SparseWeights n_blocks = 4 BLOCK_SIZE = 128 @@ -73,9 +73,9 @@ def to_block_sparse_layout(matrix: torch.Tensor, block_size: int): matrix = matrix.flatten(2, 3).sum(dim=-1) return matrix.cpu().bool().to(torch.int64) - layout = to_block_sparse_layout(W, BLOCK_SIZE).unsqueeze(0) + layout = to_block_sparse_layout(W, block_size).unsqueeze(0) # creat inputs - op = matmul(layout, BLOCK_SIZE, op_mode, trans_a=AT, trans_b=BT, device="cuda") + op = matmul(layout, block_size, op_mode, trans_a=AT, trans_b=BT, device="cuda") return op diff --git a/mttl/models/modifiers/sparse_utils/profile_sparse_mask.py b/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask.py similarity index 99% rename from mttl/models/modifiers/sparse_utils/profile_sparse_mask.py rename to mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask.py index 39caa85a7..70584829a 100644 --- a/mttl/models/modifiers/sparse_utils/profile_sparse_mask.py +++ b/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask.py @@ -10,7 +10,7 @@ from mttl.logging import logger from mttl.models.modifiers import modify_transformer from mttl.models.modifiers.lora import LoRAConfig -from mttl.models.modifiers.sparse_mask import ( +from mttl.models.modifiers.spasity.sparse_mask import ( MaskedLinear, ScatteredSparseLinearModule, SparseLinearModule, diff --git a/mttl/models/modifiers/sparse_utils/profile_sparse_mask_only_linear.py b/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask_only_linear.py similarity index 99% rename from mttl/models/modifiers/sparse_utils/profile_sparse_mask_only_linear.py rename to mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask_only_linear.py index cb91c3be3..619e733b8 100644 --- a/mttl/models/modifiers/sparse_utils/profile_sparse_mask_only_linear.py +++ b/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask_only_linear.py @@ -11,7 +11,7 @@ from mttl.logging import logger from mttl.models.modifiers import modify_transformer from mttl.models.modifiers.lora import LoRA, LoRAConfig -from mttl.models.modifiers.sparse_mask import ( +from mttl.models.modifiers.spasity.sparse_mask import ( MaskedLinear, ScatteredSparseLinearModule, SparseLinearModule, diff --git a/mttl/models/modifiers/sparse_utils/utils.py b/mttl/models/modifiers/spasity/sparse_utils/utils.py similarity index 99% rename from mttl/models/modifiers/sparse_utils/utils.py rename to mttl/models/modifiers/spasity/sparse_utils/utils.py index 4eb617fb7..61a831aef 100644 --- a/mttl/models/modifiers/sparse_utils/utils.py +++ b/mttl/models/modifiers/spasity/sparse_utils/utils.py @@ -283,7 +283,9 @@ def init_sparse_weights(sps_type, keep_ratio, shape, block_size=None): def make_sparse_model_during_training(module, 
batch): - from mttl.models.modifiers.sparse_mask import SparseMaskAdapter as SparseMaskModule + from mttl.models.modifiers.spasity.sparse_mask import ( + SparseMaskAdapter as SparseMaskModule, + ) for m in module.modules(): if isinstance(m, SparseMaskModule): diff --git a/mttl/models/modifiers/spasity/stk/__init__.py b/mttl/models/modifiers/spasity/stk/__init__.py new file mode 100644 index 000000000..88e313b90 --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/__init__.py @@ -0,0 +1 @@ +# largely inspired/adopted from STK: https://github.com/stanford-futuredata/stk diff --git a/mttl/models/modifiers/spasity/stk/autocast.py b/mttl/models/modifiers/spasity/stk/autocast.py new file mode 100644 index 000000000..6f50ab11a --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/autocast.py @@ -0,0 +1,40 @@ +import functools + +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + + return decorate_bwd diff --git a/mttl/models/modifiers/spasity/stk/functions.py b/mttl/models/modifiers/spasity/stk/functions.py new file mode 100644 index 000000000..9c336bafc --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/functions.py @@ -0,0 +1,54 @@ +import torch + +import mttl.models.modifiers.spasity.stk.triton_kernels as backend +from mttl.models.modifiers.spasity.matrix import Matrix +from mttl.models.modifiers.spasity.stk.autocast import custom_bwd, custom_fwd + + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device, + ) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply + + +class SDD_SpMerge(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx, + lhs, + rhs, + shape, + data, + row_indices, + column_indices, + column_indices_t, + block_offsets_t, + adap_data, + ada_maping, + ): + # note for later: here we will need ofdfsets transpose and offsets for the baclkward pass if we implement it + out = torch.empty(data.shape, dtype=lhs.dtype, device=lhs.device) + backend.sdd_spmerge( + lhs, rhs, shape, out, row_indices, column_indices, adap_data, ada_maping + ) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + raise NotImplementedError + + +sdd_spsmerge = SDD_SpMerge.apply diff --git a/mttl/models/modifiers/spasity/stk/linear_ops.py b/mttl/models/modifiers/spasity/stk/linear_ops.py new file mode 100644 index 000000000..64f8cdf0e --- /dev/null +++ 
b/mttl/models/modifiers/spasity/stk/linear_ops.py @@ -0,0 +1,45 @@ +import torch + +from mttl.models.modifiers.spasity import Matrix +from mttl.models.modifiers.spasity.stk import functions + + +def sdd_adamerge(a, b, out_topo: Matrix, out_adaps: Matrix, layout): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(out_topo, Matrix) + assert out_topo.is_contiguous() + assert isinstance(out_adaps, Matrix) + assert out_adaps.data.is_contiguous() + assert isinstance(layout, torch.Tensor) + assert layout.is_contiguous() + # essentially merged the adapters into a single Matrix() + assert ( + out_adaps.shape[1] == out_topo.shape[1] + ), "This performs sparse SDD of a and b, the output topo should have the same number of columns as the out_adaps" + assert ( + out_adaps.shape[1] % b.size(1) == 0 + ), "The number of columns in out_adaps should be a multiple of the number of columns in b" + + out = functions.sdd_spsmerge( + a, + b, + out_topo.size(), + out_topo.data, + out_topo.row_indices, + out_topo.column_indices, + out_topo.column_indices_t, + out_topo.block_offsets_t, + out_adaps.data, + layout, + ) + return Matrix( + out_topo.size(), + out, + out_topo.row_indices, + out_topo.column_indices, + out_topo.offsets, + out_topo.column_indices_t, + out_topo.offsets_t, + out_topo.block_offsets_t, + ) diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test.py b/mttl/models/modifiers/spasity/stk/linear_ops_test.py new file mode 100644 index 000000000..86e42c4a8 --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/linear_ops_test.py @@ -0,0 +1,146 @@ +import itertools +import os +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from mttl.models.modifiers.spasity import Matrix +from mttl.models.modifiers.spasity.stk import linear_ops, matrix_ops, random_ops + +# os.environ["TRITON_INTERPRET"] = "1" + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +blocksize = 16 +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
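+# Each tuple is (m, k, n, sparsity); most cases are commented out to keep the
+# test quick.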
+_MATRIX_SIZES = ( + (128, 128, 128, 0.8), + (128, 128, 64, 0.8), + (128, 128, 128, 0.0), + # (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + # (512, 128, 128, 0.0), + # (128, 128, 512, 0.0), + # (1024, 512, 512, 0.0), + # (1024, 512, 512, 0.5), + # (1024, 512, 512, 0.75), + # (512, 512, 1024, 0.0), + # (512, 512, 1024, 0.5), + # (512, 512, 1024, 0.75), + # (1024, 1024, 1024, 0.0), + # (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + # (False, True), + # (True, False), + # (True, True), +) + +_DTYPE = (torch.float16, torch.bfloat16) + + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [ + (*size, *trans, blocksize, dtype) for (size, trans, dtype) in testcases + ] + return testcases + + +_LINEAR_OP_TESTS = _generate_testcases() + + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = random_ops.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = matrix_ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return ( + dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True), + ) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _mmm_with_adapters(a, W_base, topo, adapters): + b = W_base.repeat(1, len(adapters)) + adaps_as_dense = [matrix_ops.to_dense(adap) for adap in adapters] + b = b + torch.cat(adaps_as_dense, dim=1) + mask = matrix_ops.to_dense(matrix_ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + def testLinearOps_Sdd_wAdapters( + self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype + ): + if trans_a or trans_b: + return + # Construct the operands. + # This tests the use-case where we have base weights and a bunch of adapters. We perform SDD of input x with base weights, but block-ssparse adapters are merged into the base weights first. + + a_shape = (m, k) + a, acp = _dense_2x(*a_shape, dtype) + + n_adaps = 10 + adapters = [ + _dense_and_sparse(*(k, n), sparsity, blocking, dtype)[1] + for _ in range(n_adaps) + ] + # merge all adapters into a single sparse Matrix() + adaps: Matrix = matrix_ops.merge_adapters(adapters) + + out_shape = (m, n * n_adaps) + _, out_topo = _dense_and_sparse(*out_shape, sparsity, blocking, dtype) + # create a mapping from out_topo to adaps, indicating whether each out_topo bvlock needs to be merged with an adapter block, and if so which one + layout = matrix_ops.create_ada_layout(adaps) + + w_shape = (k, n) + W_base, W_basecp = _dense_2x(*w_shape, dtype) + # Execute the matmul. + out = linear_ops.sdd_adamerge(a, W_base, out_topo, adaps, layout) + expected_out = _mmm_with_adapters(acp, W_basecp, out_topo, adapters) + + adapters_as_dense = torch.cat( + [matrix_ops.to_dense(adap) for adap in adapters], dim=1 + ) + adaps_as_dense = matrix_ops.to_dense(adaps) + assert ( + torch.sum(adapters_as_dense != adaps_as_dense) == 0 + ), "adapters and adaps should be the same" + + # Validate the results. 
+ out = matrix_ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + +if __name__ == "__main__": + unittest.main() diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops.py b/mttl/models/modifiers/spasity/stk/matrix_ops.py new file mode 100644 index 000000000..9fb3540c4 --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/matrix_ops.py @@ -0,0 +1,195 @@ +from typing import List + +import numpy as np +import torch + +from mttl.models.modifiers.spasity import Matrix +from mttl.models.modifiers.spasity.stk import functions + +# mostly taken/adapter from STK: https://github.com/stanford-futuredata/stk + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return functions.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. +@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape( + torch.arange(blocking, device=idxs.device), [1, blocking] + ) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape( + torch.arange(blocking, device=idxs.device), [1, blocking, 1] + ) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix( + x.size(), torch.ones_like(x.data), x.row_indices, x.column_indices, x.offsets + ) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() + + +def merge_adapters(adapters: List[Matrix]) -> Matrix: + """ + Merges a list of adapters into a single adapter along the second dimention. + """ + col_indices_list = [adap.column_indices for adap in adapters] + row_indices_list = [adap.row_indices for adap in adapters] + offsets_list = [adap.offsets for adap in adapters] + data_list = [adap.data for adap in adapters] + + num_rows = [offsets.numel() - 1 for offsets in offsets_list] + assert all( + num_rows[0] == num_rows[i] for i in range(1, len(num_rows)) + ), "All adapters must have the same number of rows" + + block_size = adapters[0].blocking + K, N = adapters[0].size() + col_offset = N // block_size # assuming all have same number of cols + n_adaps = len(adapters) + + adjusted_col_indices = [] + for e, col_idx in enumerate(col_indices_list): + adjusted_col_indices.append(col_idx + e * col_offset) + + merged_col_indices = torch.cat(adjusted_col_indices) + row_indices = torch.cat([adap.row_indices for adap in adapters], dim=0) + data = torch.cat([adap.data for adap in adapters], dim=0) + + indices = torch.stack([row_indices, merged_col_indices], dim=1) + + if indices.is_cuda: + indices = indices.cpu() + + # Convert to NumPy + np_tensor = indices.numpy() + # Perform lexsort: sort by second key first, then first key + sorted_indices = np.lexsort((np_tensor[:, 1], np_tensor[:, 0])) + + data = data[sorted_indices].contiguous() + row_indices = row_indices[sorted_indices].contiguous() + col_indices = merged_col_indices[sorted_indices].contiguous() + + # recalculate offsets + num_rows = max(num_rows) + offsets = torch.zeros(num_rows + 1, dtype=torch.int32, device=row_indices.device) + counts_per_row = torch.bincount(row_indices, minlength=num_rows) + offsets[1:] = torch.cumsum(counts_per_row, dim=0) + offsets = offsets.contiguous() + + return Matrix((K, n_adaps * N), data, row_indices, col_indices, offsets) + + +def create_ada_indices( + row_indices, column_indices, ada_row_indices, ada_column_indices, device +): + """ """ + nnz_blocks = len(row_indices) + ada_block_map = {} + for idx, (r, c) in enumerate(zip(ada_row_indices, ada_column_indices)): + ada_block_map[(r.item(), c.item())] = idx + + ada_indices = torch.full((nnz_blocks,), -1, dtype=torch.int32, device=device) + for pid in range(nnz_blocks): + pid_m = row_indices[pid].item() + pid_n = column_indices[pid].item() + if (pid_m, pid_n) in ada_block_map: + ada_indices[pid] = ada_block_map[(pid_m, pid_n)] + return ada_indices + + +def create_ada_layout(matix: Matrix): + """ + Creates a binary tensor that identifies if block exists in the adapter matrix + """ + block_size = matix.blocking + layout = ( + torch.ones( + (matix.size()[0] // block_size, matix.size()[1] // block_size), + dtype=torch.int32, + device=matix.device, + ) + * -1 + ) + blck = 0 + for r, c in zip(matix.row_indices, matix.column_indices): + layout[r.item(), c.item()] = blck + blck += 1 + return layout.contiguous() diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops_test.py b/mttl/models/modifiers/spasity/stk/matrix_ops_test.py new file mode 100644 index 000000000..44a341696 --- /dev/null +++ 
b/mttl/models/modifiers/spasity/stk/matrix_ops_test.py @@ -0,0 +1,64 @@ +import unittest + +import torch +from absl.testing import parameterized + +from mttl.models.modifiers.spasity.stk import matrix_ops, random_ops + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, 0.95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, 0.95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, 0.875, 128), +) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = random_ops.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = matrix_ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking**2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. + self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == "__main__": + unittest.main() diff --git a/mttl/models/modifiers/spasity/stk/random_ops.py b/mttl/models/modifiers/spasity/stk/random_ops.py new file mode 100644 index 000000000..59a15a3c1 --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/random_ops.py @@ -0,0 +1,37 @@ +import numpy as np +import torch + +from mttl.models.modifiers.spasity.stk import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), (1, blocking, 1, blocking) + ) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/mttl/models/modifiers/spasity/stk/triton_kernels.py b/mttl/models/modifiers/spasity/stk/triton_kernels.py new file mode 100644 index 000000000..f5c1ed0a6 --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/triton_kernels.py @@ -0,0 +1,197 @@ +from dataclasses import dataclass + +import torch +import triton +import triton.language as tl + + +@dataclass +class TritonConfig: + BLOCK_M: int = 16 # 128 + BLOCK_N: int = 16 # 128 + BLOCK_K: int = 16 # 32 + # BLOCK_SIZE: int = 
128 # block size in the output matrix? + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config( + { + "BLOCK_M": TritonConfig.BLOCK_M, + "BLOCK_N": TritonConfig.BLOCK_N, + "BLOCK_K": TritonConfig.BLOCK_K, + # "BLOCK_SIZE": TritonConfig.BLOCK_SIZE, + }, + num_stages=TritonConfig.NUM_STAGES, + num_warps=TritonConfig.NUM_WARPS, + ), + ], + key=["M", "N", "K"], # uses these keys to decide wether to re-evaluate the choise of best config +) +@triton.jit # this is understood +def _sdd_adamerge( + A, + B, + S, + OUT, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + row_indices, + column_indices, + layout, + stride_layout_m, + stride_layout_n, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + # matrix multiplication + pid = tl.program_id(0) # in triton only control thread blocks + pid_m = tl.load( + row_indices + pid + ) # row index of the block in the output matrix that is being computed by this thread block + pid_n = tl.load( + column_indices + pid + ) # column index of the block in the output matrix that is being computed by this thread block + rm = pid_m * BLOCK_M + tl.arange( + 0, BLOCK_M + ) # the actual row indices in the output matrix + rn = pid_n * BLOCK_N + tl.arange( + 0, BLOCK_N + ) # the actual column indices in the output matrix + ram = tl.max_contiguous( + tl.multiple_of(rm % M, BLOCK_M), BLOCK_M + ) # optimizes memory throughput by ensuring that the memory accesses are contiguous + rbn = tl.max_contiguous( + tl.multiple_of(rn % N, BLOCK_N), BLOCK_N + ) # optimizes memory throughput by ensuring that the memory accesses are contiguous + rk = tl.arange(0, BLOCK_K) # innialize inner dimention range for the current block + BLOCK_ELEMENTS = BLOCK_M * BLOCK_N # BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + # pointers + A = A + ( + ram[:, None] * stride_am + rk[None, :] * stride_ak + ) # BLOCK_M x BLOCK_K pointes to the dense matrix A for loading + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + s_blck = tl.load(layout + k * stride_layout_m + pid_n * stride_layout_n) + mask = s_blck >= 0 + s_blck = tl.where(mask, s_blck, 0) + s_ptr = ( + S + + s_blck * BLOCK_ELEMENTS + + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + ) + s = tl.load(s_ptr) + s = tl.where(mask[None, None], s, tl.zeros_like(s)) + b = b + s + acc += tl.dot(a, b) # this should be using tensor cores on A100 + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + # Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + # remember, in OUT we only store the non-zero elements, so no need to map it to dense matrix + OUT = ( + OUT + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + ) + tl.store(OUT, acc, mask=True) + + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + 
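+    # One program per block row: read this row's offset range and write the
+    # row id for every nonzero block it owns.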
row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + + +def row_indices(shape, data, offsets, column_indices, out): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows,)](offsets, out) + + +def sdd_spmerge( + x, + base_weights, + shape, + out, + row_indices, + column_indices, + ada_data, + ada_layout, # +): + # E is the number of experts + # ada_data is (E x n_blocks_per_e) x block_size x block_size + # base_weights is dense matrix of shape (K, (expert_out_dim x E) + # ada_row_indices is (E x n_blocks_per_e) + # ada_column_indices is (E x n_blocks_per_e) + # base_weights.shape[1 = expert out dim. + + assert x.shape[1] == base_weights.shape[0], "incompatible dimensions" + M, K = x.shape + _, N = base_weights.shape + assert ( + shape[1] & N == 0 + ), "RHS out dimension must be divisible by base weights output dim." + E = shape[1] // N + block_size = ada_data.shape[1] + + _validate_matmul_dims(M, K, N) + + if out.dtype in [torch.float16, torch.bfloat16, torch.float32]: + ACC_TYPE = tl.float32 + else: + raise ValueError(f"Unsupported dtype: {out.dtype}") + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) # this just alunches 61 threadblocks + + stride_am, stride_ak = x.stride(0), x.stride(1) + stride_bk, stride_bn = base_weights.stride(0), base_weights.stride(1) + + _sdd_adamerge[grid]( + x, + base_weights, + ada_data, + out, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + out.stride(1), + out.stride(2), + row_indices, + column_indices, + ada_layout, + ada_layout.stride(0), + ada_layout.stride(1), + ACC_TYPE=ACC_TYPE, + ) diff --git a/tests/test_sparse_masks.py b/tests/test_sparse_masks.py index ca56e0fec..9fe45c684 100644 --- a/tests/test_sparse_masks.py +++ b/tests/test_sparse_masks.py @@ -7,7 +7,7 @@ from pytorch_lightning import seed_everything from mttl.models.modifiers import modify_transformer -from mttl.models.modifiers.sparse_mask import ( +from mttl.models.modifiers.spasity.sparse_mask import ( MaskedLinear, ScatteredSparseLinearModule, SNIPMaskUpdateWrapper, From 4206e840b5a941ef9406eb215ba9ca30007be688 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 27 Sep 2024 10:16:28 -0400 Subject: [PATCH 05/24] require stk --- mttl/models/modifiers/spasity/__init__.py | 2 - mttl/models/modifiers/spasity/matrix.py | 364 ------------------ mttl/models/modifiers/spasity/stk/__init__.py | 1 - mttl/models/modifiers/spasity/stk/autocast.py | 40 -- .../models/modifiers/spasity/stk/functions.py | 4 +- .../modifiers/spasity/stk/linear_ops.py | 2 +- .../modifiers/spasity/stk/linear_ops_test.py | 19 +- .../modifiers/spasity/stk/matrix_ops.py | 121 +----- .../modifiers/spasity/stk/matrix_ops_test.py | 64 --- .../modifiers/spasity/stk/random_ops.py | 37 -- .../modifiers/spasity/stk/triton_kernels.py | 2 +- requirements.txt | 1 + 12 files changed, 16 insertions(+), 641 deletions(-) delete mode 100644 mttl/models/modifiers/spasity/matrix.py delete mode 100644 mttl/models/modifiers/spasity/stk/autocast.py delete mode 100644 mttl/models/modifiers/spasity/stk/matrix_ops_test.py delete mode 100644 mttl/models/modifiers/spasity/stk/random_ops.py diff --git a/mttl/models/modifiers/spasity/__init__.py b/mttl/models/modifiers/spasity/__init__.py index 67fb01f06..49ea5eda6 100644 --- a/mttl/models/modifiers/spasity/__init__.py +++ b/mttl/models/modifiers/spasity/__init__.py @@ -1,3 +1 @@ -# largely inspired/adopted from 
STK: https://github.com/stanford-futuredata/stk -from mttl.models.modifiers.spasity.matrix import Matrix from mttl.models.modifiers.spasity.sparse_mask import * diff --git a/mttl/models/modifiers/spasity/matrix.py b/mttl/models/modifiers/spasity/matrix.py deleted file mode 100644 index dc6fe8be5..000000000 --- a/mttl/models/modifiers/spasity/matrix.py +++ /dev/null @@ -1,364 +0,0 @@ -import numpy as np -import torch - -# this is copy paste from stk: https://github.com/stanford-futuredata/stk - - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}" - ) - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). " - f"Got shape {data.dim()}D shape." - ) - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking." - ) - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})" - ) - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices." - ) - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices." - ) - - if offsets.dim() != 1: - raise ValueError(f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks" - ) - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks" - ) - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows." - ) - - is_cuda = ( - data.is_cuda - and row_indices.is_cuda - and column_indices.is_cuda - and offsets.is_cuda - ) - is_cpu = ( - not data.is_cuda - and not row_indices.is_cuda - and not column_indices.is_cuda - and not offsets.is_cuda - ) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}." - ) - - if data.dtype != torch.float16: - raise ValueError(f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices." 
- ) - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices." - ) - if offsets.dtype != torch.int32: - raise ValueError(f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. - """ - - def __init__( - self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None, - ): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ( - (column_indices_t is None) - or (offsets_t is None) - or (block_offsets_t is None) - ): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets - ) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices." - ) - - def validate(self): - _validate_matrix( - self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets, - ) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. 
- self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone(), - ) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D." - ) - out = Matrix( - self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t, - ) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}." - ) - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}" - ) - return Matrix( - shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t, - ) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. 
- size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix( - size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t, - ) - return out if self.is_contiguous() else out.t() diff --git a/mttl/models/modifiers/spasity/stk/__init__.py b/mttl/models/modifiers/spasity/stk/__init__.py index 88e313b90..e69de29bb 100644 --- a/mttl/models/modifiers/spasity/stk/__init__.py +++ b/mttl/models/modifiers/spasity/stk/__init__.py @@ -1 +0,0 @@ -# largely inspired/adopted from STK: https://github.com/stanford-futuredata/stk diff --git a/mttl/models/modifiers/spasity/stk/autocast.py b/mttl/models/modifiers/spasity/stk/autocast.py deleted file mode 100644 index 6f50ab11a..000000000 --- a/mttl/models/modifiers/spasity/stk/autocast.py +++ /dev/null @@ -1,40 +0,0 @@ -import functools - -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - - return decorate_bwd diff --git a/mttl/models/modifiers/spasity/stk/functions.py b/mttl/models/modifiers/spasity/stk/functions.py index 9c336bafc..77271d187 100644 --- a/mttl/models/modifiers/spasity/stk/functions.py +++ b/mttl/models/modifiers/spasity/stk/functions.py @@ -1,8 +1,8 @@ import torch +from stk.backend.autocast import custom_bwd, custom_fwd +from stk.matrix import Matrix import mttl.models.modifiers.spasity.stk.triton_kernels as backend -from mttl.models.modifiers.spasity.matrix import Matrix -from mttl.models.modifiers.spasity.stk.autocast import custom_bwd, custom_fwd class RowIndices(torch.autograd.Function): diff --git a/mttl/models/modifiers/spasity/stk/linear_ops.py b/mttl/models/modifiers/spasity/stk/linear_ops.py index 64f8cdf0e..963908f55 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops.py +++ b/mttl/models/modifiers/spasity/stk/linear_ops.py @@ -1,6 +1,6 @@ import torch +from stk.matrix import Matrix -from mttl.models.modifiers.spasity import Matrix from mttl.models.modifiers.spasity.stk import functions diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test.py b/mttl/models/modifiers/spasity/stk/linear_ops_test.py index 86e42c4a8..dd543f096 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops_test.py +++ b/mttl/models/modifiers/spasity/stk/linear_ops_test.py @@ -3,11 +3,12 @@ import unittest import numpy as np +import stk import torch from absl.testing import parameterized +from stk.matrix import Matrix -from mttl.models.modifiers.spasity import Matrix -from mttl.models.modifiers.spasity.stk import linear_ops, matrix_ops, random_ops +from 
mttl.models.modifiers.spasity.stk import linear_ops, matrix_ops # os.environ["TRITON_INTERPRET"] = "1" @@ -65,9 +66,9 @@ def _generate_testcases(): def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = random_ops.dense_mask(rows, cols, sparsity, blocking) + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = matrix_ops.to_sparse(dense, blocking) + sparse = stk.ops.to_sparse(dense, blocking) cuda_device = torch.device("cuda") return ( dense.to(cuda_device).requires_grad_(True), @@ -88,9 +89,9 @@ def _dense_2x(rows, cols, dtype): def _mmm_with_adapters(a, W_base, topo, adapters): b = W_base.repeat(1, len(adapters)) - adaps_as_dense = [matrix_ops.to_dense(adap) for adap in adapters] + adaps_as_dense = [stk.ops.to_dense(adap) for adap in adapters] b = b + torch.cat(adaps_as_dense, dim=1) - mask = matrix_ops.to_dense(matrix_ops.ones_like(topo)) + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) return torch.mm(a, b) * mask @@ -127,15 +128,15 @@ def testLinearOps_Sdd_wAdapters( expected_out = _mmm_with_adapters(acp, W_basecp, out_topo, adapters) adapters_as_dense = torch.cat( - [matrix_ops.to_dense(adap) for adap in adapters], dim=1 + [stk.ops.to_dense(adap) for adap in adapters], dim=1 ) - adaps_as_dense = matrix_ops.to_dense(adaps) + adaps_as_dense = stk.ops.to_dense(adaps) assert ( torch.sum(adapters_as_dense != adaps_as_dense) == 0 ), "adapters and adaps should be the same" # Validate the results. - out = matrix_ops.to_dense(out) + out = stk.ops.to_dense(out) self.assertEqual(out.dim(), 2) self.assertEqual(expected_out.size()[0], out.size()[0]) self.assertEqual(expected_out.size()[1], out.size()[1]) diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops.py b/mttl/models/modifiers/spasity/stk/matrix_ops.py index 9fb3540c4..6276ecc4d 100644 --- a/mttl/models/modifiers/spasity/stk/matrix_ops.py +++ b/mttl/models/modifiers/spasity/stk/matrix_ops.py @@ -2,108 +2,7 @@ import numpy as np import torch - -from mttl.models.modifiers.spasity import Matrix -from mttl.models.modifiers.spasity.stk import functions - -# mostly taken/adapter from STK: https://github.com/stanford-futuredata/stk - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return functions.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. -@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape( - torch.arange(blocking, device=idxs.device), [1, blocking] - ) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape( - torch.arange(blocking, device=idxs.device), [1, blocking, 1] - ) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. 
-@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. - data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix( - x.size(), torch.ones_like(x.data), x.row_indices, x.column_indices, x.offsets - ) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() +from stk.matrix import Matrix def merge_adapters(adapters: List[Matrix]) -> Matrix: @@ -157,24 +56,6 @@ def merge_adapters(adapters: List[Matrix]) -> Matrix: return Matrix((K, n_adaps * N), data, row_indices, col_indices, offsets) -def create_ada_indices( - row_indices, column_indices, ada_row_indices, ada_column_indices, device -): - """ """ - nnz_blocks = len(row_indices) - ada_block_map = {} - for idx, (r, c) in enumerate(zip(ada_row_indices, ada_column_indices)): - ada_block_map[(r.item(), c.item())] = idx - - ada_indices = torch.full((nnz_blocks,), -1, dtype=torch.int32, device=device) - for pid in range(nnz_blocks): - pid_m = row_indices[pid].item() - pid_n = column_indices[pid].item() - if (pid_m, pid_n) in ada_block_map: - ada_indices[pid] = ada_block_map[(pid_m, pid_n)] - return ada_indices - - def create_ada_layout(matix: Matrix): """ Creates a binary tensor that identifies if block exists in the adapter matrix diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops_test.py b/mttl/models/modifiers/spasity/stk/matrix_ops_test.py deleted file mode 100644 index 44a341696..000000000 --- a/mttl/models/modifiers/spasity/stk/matrix_ops_test.py +++ /dev/null @@ -1,64 +0,0 @@ -import unittest - -import torch -from absl.testing import parameterized - -from mttl.models.modifiers.spasity.stk import matrix_ops, random_ops - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, 0.95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, 
0.95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, 0.875, 128), -) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = random_ops.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = matrix_ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking**2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. - self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == "__main__": - unittest.main() diff --git a/mttl/models/modifiers/spasity/stk/random_ops.py b/mttl/models/modifiers/spasity/stk/random_ops.py deleted file mode 100644 index 59a15a3c1..000000000 --- a/mttl/models/modifiers/spasity/stk/random_ops.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np -import torch - -from mttl.models.modifiers.spasity.stk import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), (1, blocking, 1, blocking) - ) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/mttl/models/modifiers/spasity/stk/triton_kernels.py b/mttl/models/modifiers/spasity/stk/triton_kernels.py index f5c1ed0a6..871d8fa58 100644 --- a/mttl/models/modifiers/spasity/stk/triton_kernels.py +++ b/mttl/models/modifiers/spasity/stk/triton_kernels.py @@ -112,7 +112,7 @@ def _sdd_adamerge( A += BLOCK_K * stride_ak B += BLOCK_K * stride_bk # Store to sparse matrix - acc = acc.to(C.dtype.element_ty) + acc = acc.to(OUT.dtype.element_ty) # remember, in OUT we only store the non-zero elements, so no need to map it to dense matrix OUT = ( OUT + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) diff --git a/requirements.txt b/requirements.txt index c6051bc59..b44658852 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,5 @@ azure-storage-blob azure-identity einops nltk 
+stanford-stk # spops @ git+https://github.com/IST-DASLab/spops.git@main From 334df5eb5546731cfb4eb42f499d612d9d958003 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 3 Oct 2024 15:48:22 -0400 Subject: [PATCH 06/24] scattared implementation --- ...er_merging.py => bsr_ddsloop_benchmark.py} | 0 .../spasity/sparse_utils/bsr_moe_benchmark.py | 244 +++++++ .../sparse_utils/profile_block_sparsity.py | 200 ------ .../sparse_utils/profile_sparse_mask.py | 329 --------- .../profile_sparse_mask_only_linear.py | 352 ---------- .../modifiers/spasity/sparse_utils/utils.py | 138 ++++ .../models/modifiers/spasity/stk/functions.py | 143 ++++ .../modifiers/spasity/stk/linear_ops.py | 130 ++++ ...ps_test.py => linear_ops_test_megatron.py} | 105 ++- .../spasity/stk/linear_ops_test_scatter.py | 143 ++++ .../modifiers/spasity/stk/matrix_ops.py | 36 +- .../modifiers/spasity/stk/matrix_ops_test.py | 39 ++ .../modifiers/spasity/stk/measure_time.py | 180 +++++ .../spasity/stk/scatter_moe_kernels.py | 622 ++++++++++++++++++ .../modifiers/spasity/stk/triton_kernels.py | 539 +++++++++++++++ 15 files changed, 2314 insertions(+), 886 deletions(-) rename mttl/models/modifiers/spasity/sparse_utils/{profile_adapter_merging.py => bsr_ddsloop_benchmark.py} (100%) create mode 100644 mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py delete mode 100644 mttl/models/modifiers/spasity/sparse_utils/profile_block_sparsity.py delete mode 100644 mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask.py delete mode 100644 mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask_only_linear.py rename mttl/models/modifiers/spasity/stk/{linear_ops_test.py => linear_ops_test_megatron.py} (54%) create mode 100644 mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py create mode 100644 mttl/models/modifiers/spasity/stk/matrix_ops_test.py create mode 100644 mttl/models/modifiers/spasity/stk/measure_time.py create mode 100644 mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py diff --git a/mttl/models/modifiers/spasity/sparse_utils/profile_adapter_merging.py b/mttl/models/modifiers/spasity/sparse_utils/bsr_ddsloop_benchmark.py similarity index 100% rename from mttl/models/modifiers/spasity/sparse_utils/profile_adapter_merging.py rename to mttl/models/modifiers/spasity/sparse_utils/bsr_ddsloop_benchmark.py diff --git a/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py b/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py new file mode 100644 index 000000000..da6f79d8d --- /dev/null +++ b/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py @@ -0,0 +1,244 @@ +import logging +import re +import time +from typing import List + +import numpy as np +import pandas as pd +import stk.ops +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton as tn +from pytorch_lightning import seed_everything +from spops import csr_add, spmm +from stk.matrix import Matrix +from triton.ops.blocksparse import matmul + +from mttl.logging import logger +from mttl.models.modifiers import modify_transformer +from mttl.models.modifiers.base import Modifier +from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig +from mttl.models.modifiers.spasity.sparse_mask import ( + MaskedLinear, + ScatteredSparseLinearModule, + SparseLinearModule, + SparseMaskAdapter, + SparseMaskConfig, +) +from mttl.models.modifiers.spasity.sparse_utils.utils import ( + padded_gather, + padded_scatter, +) +from mttl.models.modifiers.spasity.stk import 
linear_ops, matrix_ops +from mttl.models.utils import model_loader_helper, transfer_batch_to_device + +device = "cuda" +logger.setLevel(logging.ERROR) +block_size = 16 # 128 # 16 +n_blocks = 1024 # 16 # 1024 +in_d = 2048 +out_d = 8192 +dtype = torch.bfloat16 +max_seq_len = 1024 +bs = 2 +layer = nn.Linear(in_d, out_d).to(device) +layer.weight.requires_grad_(False) +layer.bias.requires_grad_(False) +K = 100 +top_k = 2 + + +def calculate_lora_parameters(input_dim, output_dim, rank): + return input_dim * rank + output_dim * rank + + +def find_hyperpaams(): + modules = {"linear": layer} + modified_modules = {} + keep_ratios = [] + lora_ranks = [] + + for name, module in modules.items(): + keep_ratio = ( + n_blocks * (block_size**2) / (module.in_features * module.out_features) + ) + tot_sparse_params = module.in_features * module.out_features * keep_ratio + lora_rank = 1 + for rank in range(1, module.in_features): + lora_params = calculate_lora_parameters( + module.in_features, module.out_features, rank + ) + if lora_params <= tot_sparse_params: + lora_rank = rank + else: + break + modified_modules[name] = { + "module": module, + "keep_ratio": keep_ratio, + "lora_rank": lora_rank, + } + keep_ratios.append(keep_ratio) + lora_ranks.append(lora_rank) + return np.mean(keep_ratios), int(np.mean(lora_ranks)) + + +keep_ratio, lora_rank = find_hyperpaams() +print( + f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}, Lora params: {calculate_lora_parameters(in_d, out_d, lora_rank)}, Sparse params: {in_d * out_d * keep_ratio}" +) +x = torch.randn(bs, max_seq_len, in_d, dtype=dtype, device=device).contiguous() + + +def create_adapter_set(adapter_config, layer, K) -> List[Modifier]: + if isinstance(adapter_config, SparseMaskConfig): + layer = nn.Linear(out_d, in_d) # TODO: implement transpose in SparseWeights + module = [SparseMaskAdapter(adapter_config, layer) for _ in range(K)] + elif isinstance(adapter_config, LoRAConfig): + module = [LoRA(adapter_config, layer) for _ in range(K)] + return module + + +def sparsemodules_to_stkmatrix_list(sparse_modules): + sparse_weights = [] + for sparse_module in sparse_modules: + mtx = stk.ops.to_sparse( + sparse_module.sparse_layer.to_dense().type(dtype), blocking=block_size + ) + # mtx.validate() + sparse_weights.append(mtx) + return sparse_weights + + +@torch.autocast(device_type="cuda", dtype=dtype) +def lora_merge(lora_a, lora_b, x, W_base, W_merge): + + # merge into 1 loa + A = torch.einsum("ble,edr->bldr", (W_merge, lora_a)) + B = torch.einsum("ble,erd->blrd", (W_merge, lora_b)) + # lora forward + partial_out = torch.einsum("bld,bldr->blr", (x, A)) + adapter_out = torch.einsum("blr,blrd->bld", (partial_out, B)) + dense_out = x @ W_base + return adapter_out + dense_out + + +def create_block_diagonal_matrix(bs_m, bs_n, n_blocks): + assert bs_m >= block_size + assert bs_n >= block_size + factor = (bs_m * bs_n) // (block_size**2) + + M = bs_m * n_blocks + N = bs_n * n_blocks + + Mb = M // block_size + Nb = N // block_size + + nb_m_pb = bs_m // block_size + nb_n_pb = bs_n // block_size + + col_indices_1blk = torch.arange(nb_n_pb, device=device, dtype=torch.int32).repeat(nb_m_pb) + row_indices_1blk = torch.arange(nb_m_pb, device=device, dtype=torch.int32).repeat_interleave(nb_n_pb) + offsets = torch.arange(0, Mb * nb_n_pb + nb_n_pb, nb_n_pb, device=device) + + col_idx = torch.cat([col_indices_1blk + i * nb_n_pb for i in range(n_blocks)]) + row_idx = torch.cat([row_indices_1blk + i * nb_m_pb for i in range(n_blocks)]) + data = torch.empty((Mb * Nb, block_size, 
block_size), device=device) + + return Matrix((M, N), data, row_idx, col_idx, offsets) + + +adapter_config_lora = LoRAConfig(modify_layers="", lora_rank=lora_rank) +adapter_config_bs = SparseMaskConfig( + sps_impl="scattered", + sps_type="block_sparse", + keep_ratio=keep_ratio, + reselection_steps=1, + block_size=block_size, +) + +# FOWARD PASS through MoE +W_mege = torch.randn(bs, max_seq_len, K, dtype=dtype, device=device) +loras = create_adapter_set(adapter_config_lora, layer, K) +sparse_modules = create_adapter_set(adapter_config_bs, layer, K) +sparse_mtxs = sparsemodules_to_stkmatrix_list(sparse_modules) +adaptersMatrix: Matrix = matrix_ops.merge_adapters(sparse_mtxs).to(device) + +W_mege = W_mege.to(dtype=loras[0].lora_a.dtype) +top_k_indices = torch.topk(torch.abs(W_mege), top_k, dim=-1).indices +( + x, + num_tokens_per_expert, + sort_order, + indices_expert_padded, + positions_in_expert_padded, + padding_mask, +) = padded_gather(x, top_k_indices, K) +layout = matrix_ops.create_ada_layout(adaptersMatrix).to(device) + +out_blck_size = x.shape[1] +x = x.reshape(-1, in_d).contiguous() +out_topology = create_block_diagonal_matrix(out_blck_size, out_d, K) +W_base = layer.weight.T.to(dtype=dtype) +output = linear_ops.sdd_adamerge(x, W_base, out_topology, adaptersMatrix, layout) +print(output.shape) +# create output topoly + + +# @tn.testing.perf_report( +# tn.testing.Benchmark( +# x_names=["K"], # Argument names to use as an x-axis for the plot. +# x_vals=[2, 3, 4, 10, 64, 128], # Different possible values for `x_name`. +# x_log=False, # x axis is logarithmic. +# line_arg="provider", # Argument name whose value corresponds to a different line in the plot. +# line_vals=[ +# "lora", +# ], # "lora_compiled", "torch_sparse_compiled"], # Possible values for `line_arg`. +# line_names=[ +# "lora", +# ], # "lora_compiled", "torch_sparse_compiled"], # Label name for the lines. +# styles=[ +# ("blue", "-"), +# ("green", "-"), +# ("orange", "-"), +# ("red", "-"), +# ("purple", "-"), +# ("black", "-"), +# ("brown", "-"), +# ], # Line color and style. +# ylabel="ms", #'GB/s', # Label name for the y-axis. +# xlabel="K", +# plot_name="matmul-performance", # Name for the plot. Used also as a file name for saving the plot. 
+# args={"bs": bs, "max_seq_len": max_seq_len, "in_d": in_d, "d_out": out_d}, +# ) +# ) +# def benchmark(K, bs, max_seq_len, in_d, d_out, provider): +# W_mege = torch.randn(bs, max_seq_len, K, dtype=dtype, device=device) +# loras = create_adapter_set(adapter_config_lora, layer, K) +# sparse_modules = create_adapter_set(adapter_config_bs, layer, K) +# W_mege = W_mege.to(dtype=loras[0].lora_a.dtype) + +# lora_a = torch.stack([lora.lora_a for lora in loras], dim=0) +# lora_b = torch.stack([lora.lora_b for lora in loras], dim=0) +# sparse_weights: List[torch.Tensor] = [ +# sparse_module.sparse_layer.to_dense().to_sparse_csr().to(device) +# for sparse_module in sparse_modules +# ] +# sparse_weights_spops = [ +# sparse_module.sparse_layer.to(device) for sparse_module in sparse_modules +# ] + +# print("Testing provider:", provider, "K:", K) +# quantiles = [0.5, 0.2, 0.8] +# if provider == "lora": +# ms, min_ms, max_ms = tn.testing.do_bench( +# lambda: lora_merge(lora_a, lora_b, x, layer.weight.T, W_mege), +# quantiles=quantiles, +# ) + +# # gbps = lambda ms: 2 * s * h * o * 2 * 1e-9 / (ms * 1e-3) +# # return gbps(ms), gbps(max_ms), gbps(min_ms) +# return ms, max_ms, min_ms + + +# benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/mttl/models/modifiers/spasity/sparse_utils/profile_block_sparsity.py b/mttl/models/modifiers/spasity/sparse_utils/profile_block_sparsity.py deleted file mode 100644 index 446f8ed44..000000000 --- a/mttl/models/modifiers/spasity/sparse_utils/profile_block_sparsity.py +++ /dev/null @@ -1,200 +0,0 @@ -# several options to compare for block sparce operations: -# 1. triton.ops.blocksparse -- this is supposed to work fast for cases whee the sparcity structure is not changing too often -# 2. stk -- https://github.com/stanford-futuredata/stk -- this is supposed to work fast for cases where the sparcity structure is changing often -import stk -import stk.ops -import torch -import torch.nn.functional as F -import triton as tn -from spops import csr_add, sddmm -from triton.ops.blocksparse import matmul - -from mttl.models.modifiers.sparse_utils.utils import init_sparse_weights -from mttl.models.modifiers.spasity.sparse_mask import SparseMaskConfig, SparseWeights - -n_blocks = 4 -BLOCK_SIZE = 128 -dtype = torch.float16 - -sequence_length = 1024 -hidden_size = 2048 # phi2 size -mlp_size = 8192 - -sparcity = n_blocks * (BLOCK_SIZE**2) / (hidden_size * mlp_size) -print(f"sparsity: {sparcity}") - -# W = init_sparse_weights("block_sparse", 0.005, (K, N), BLOCK_SIZE).contiguous().to('cuda') -# X = torch.randn(M, K).to('cuda').contiguous() - - -def stk_sdd(X, W, topo): - return stk.ops.sdd(X, W, topo) - - -def torch_linear(X, W): - return F.linear(X, W) - - -def spops_sdd_structure_aware(X, W, topo: SparseWeights): - return sddmm(topo.row_offs, topo.row_idx, topo.col_idx, X, W) - - -def spops_sdd_sputnik(X, W, topo: SparseWeights): - return sddmm(topo.row_offs, topo.row_idx, topo.col_idx, X, W, backend="sputnik") - - -def torch_linear_w_sparse(X, W): - return F.linear(X, W) - - -def triton_blocksparse_mm(X, W, op): - return op(X, W) - - -def prepare_triton_bs_op(X, W): - Z, H = 1, 1 - AT = False - BT = False - op_mode = "sdd" - - def to_block_sparse_layout(matrix: torch.Tensor, block_size: int): - """ - Returns layout of block sparse matrix: i.e. a matrix of shape (M//block_size, N//block_size) where each element is a boolean indicating whether the block is non-zero. 
- """ - M, N = matrix.shape - assert M % block_size == 0, "M must be divisible by block_size" - assert N % block_size == 0, "N must be divisible by block_size" - matrix = matrix.reshape( - M // block_size, - block_size, - N // block_size, - block_size, - ).permute(0, 2, 1, 3) - matrix = matrix.flatten(2, 3).sum(dim=-1) - return matrix.cpu().bool().to(torch.int64) - - layout = to_block_sparse_layout(W, block_size).unsqueeze(0) - # creat inputs - op = matmul(layout, block_size, op_mode, trans_a=AT, trans_b=BT, device="cuda") - return op - - -# # adapted from https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html -@tn.testing.perf_report( - tn.testing.Benchmark( - # x_names=['o'], # Argument names to use as an x-axis for the plot. - # x_vals=[128*i for i in [8, 10, 20, 50, 64, 100]], # Different possible values for `x_name`. - x_names=["s"], # Argument names to use as an x-axis for the plot. - x_vals=[ - 128 * i for i in [8, 10, 12, 14, 16, 20] - ], # Different possible values for `x_name`. - x_log=False, # x axis is logarithmic. - line_arg="provider", # Argument name whose value corresponds to a different line in the plot. - line_vals=["naive", "stk", "triton_bs"], # Possible values for `line_arg`. - line_names=["Naive", "stk", "triton_bs"], # Label name for the lines. - styles=[ - ("blue", "-"), - ("green", "-"), - ("orange", "-"), - ("red", "-"), - ("purple", "-"), - ("black", "-"), - ], # Line color and style. - ylabel="ms", #'GB/s', # Label name for the y-axis. - xlabel="seq length dim.", - plot_name="matmul-performance", # Name for the plot. Used also as a file name for saving the plot. - args={ - "h": hidden_size, - "o": mlp_size, - "sp": sparcity, - }, # Values for function arguments not in `x_names` and `y_name`. - ) -) -def benchmark(s, h, o, sp, provider): - X = torch.rand((s, h), device="cuda", dtype=dtype).contiguous() - W = ( - init_sparse_weights("block_sparse", sp, (h, o), BLOCK_SIZE) - .to("cuda") - .to(dtype) - .contiguous() - ) - W_row_sparse = ( - init_sparse_weights("row_sparse", sp, (h, o), BLOCK_SIZE) - .to("cuda") - .to(dtype) - .contiguous() - ) - WT = W.T - assert W.sum() > 0 - assert W_row_sparse.sum() == W.sum() - - quantiles = [0.5, 0.2, 0.8] - if provider == "naive": - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: torch_linear(X, WT), quantiles=quantiles - ) - if provider == "stk": - if BLOCK_SIZE != 128 or dtype != torch.float16: - ms, min_ms, max_ms = 0, 0, 0 - else: - W_stk = stk.ops.to_sparse(W, blocking=BLOCK_SIZE) - W_stk.validate() - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: stk_sdd(X, W, W_stk), quantiles=quantiles - ) - if provider == "torch_bsr": - W_bst = WT.to_sparse_bsr(blocksize=BLOCK_SIZE) - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: torch_linear(X, W_bst), quantiles=quantiles - ) - if provider == "spops_block": - W_spops_block = SparseWeights.from_dense( - W, - SparseMaskConfig( - keep_ratio=sp, block_size=BLOCK_SIZE, sps_type="block_sparse" - ), - ).to("cuda") - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: spops_sdd_structure_aware(X, W, W_spops_block), quantiles=quantiles - ) - if provider == "spops_row": - W_row_sparse = ( - init_sparse_weights("row_sparse", sp, (h, o), BLOCK_SIZE) - .to("cuda") - .to(dtype) - .contiguous() - ) - W_spops_row = SparseWeights.from_dense( - W_row_sparse, - SparseMaskConfig( - keep_ratio=sp, block_size=BLOCK_SIZE, sps_type="row_sparse" - ), - ).to("cuda") - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: spops_sdd_structure_aware(X, W_row_sparse, W_spops_row), 
- quantiles=quantiles, - ) - if provider == "spops_sputnik_block": - W_spops_block = SparseWeights.from_dense( - W, - SparseMaskConfig( - keep_ratio=sp, block_size=BLOCK_SIZE, sps_type="block_sparse" - ), - ).to("cuda") - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: spops_sdd_sputnik(X, W, W_spops_block), quantiles=quantiles - ) - if provider == "triton_bs": - op = prepare_triton_bs_op(X, W) - X = X[None, None, ...] - W = W[None, None, ...] - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: triton_blocksparse_mm(X, W, op), quantiles=quantiles - ) - - gbps = lambda ms: 2 * s * h * o * 2 * 1e-9 / (ms * 1e-3) - # return gbps(ms), gbps(max_ms), gbps(min_ms) - return ms, max_ms, min_ms - - -benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask.py b/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask.py deleted file mode 100644 index 70584829a..000000000 --- a/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask.py +++ /dev/null @@ -1,329 +0,0 @@ -import logging -import re -import time - -import numpy as np -import pandas as pd -import torch -from pytorch_lightning import seed_everything - -from mttl.logging import logger -from mttl.models.modifiers import modify_transformer -from mttl.models.modifiers.lora import LoRAConfig -from mttl.models.modifiers.spasity.sparse_mask import ( - MaskedLinear, - ScatteredSparseLinearModule, - SparseLinearModule, - SparseMaskConfig, -) -from mttl.models.utils import model_loader_helper, transfer_batch_to_device - -logger.setLevel(logging.ERROR) -model_name = "EleutherAI/gpt-neo-125m" # "EleutherAI/gpt-neo-125m" # "phi-2" -block_size = 128 -n_blocks = 6 -mask_updater = None -modify_layers = ".*q_proj.*|.*v_proj.*|.*k_proj.*" # ".*q_proj.*|.*v_proj.*|.*k_proj.*" # ".*Wqkv.*" # -n_iters = 50 - -# input sizes and batch sizes for testing -max_seq_len = 1024 -bs = 1 -vocab_size = 32000 - - -def calculate_lora_parameters(input_dim, output_dim, rank): - return input_dim * rank + output_dim * rank - - -def find_hyperpaams(): - - model = model_loader_helper( - model_name, - bf16=True, - fp16=False, - load_in_4bit=False, - load_in_8bit=False, - device_map="cpu", - ) - modules = dict(model.named_modules()) - modified_modules = {} - keep_ratios = [] - lora_ranks = [] - - for ml in modify_layers.split("|"): - for name, module in modules.items(): - if re.match(ml, name) and ml not in modified_modules: - keep_ratio = ( - n_blocks - * (block_size**2) - / (module.in_features * module.out_features) - ) - tot_sparse_params = ( - module.in_features * module.out_features * keep_ratio - ) - lora_rank = 1 - for rank in range(1, module.in_features): - lora_params = calculate_lora_parameters( - module.in_features, module.out_features, rank - ) - if lora_params <= tot_sparse_params: - lora_rank = rank - else: - break - modified_modules[ml] = { - "module": module, - "keep_ratio": keep_ratio, - "lora_rank": lora_rank, - } - keep_ratios.append(keep_ratio) - lora_ranks.append(lora_rank) - return np.mean(keep_ratios), int(np.mean(lora_ranks)) - - -keep_ratio, lora_rank = find_hyperpaams() -print(f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}") - - -table = pd.DataFrame( - columns=[ - "Av. Runtime", - "Av. Forward time", - "Av. 
Backward time", - "Allocated Memory", - "Reserved Memory", - "Number of Parameters", - ] -) - - -def dummy_batch(): - torch.manual_seed(0) - batch = { - "input_ids": torch.randint(10, vocab_size, (bs, max_seq_len)), - "labels": torch.randint(10, vocab_size, (bs, max_seq_len)), - } - seq_len = torch.randint(0, max_seq_len, (bs,)) - attn_mask = torch.zeros(bs, max_seq_len, dtype=torch.int32) - attn_mask[torch.arange(bs), seq_len] = 1 - attn_mask = 1 - attn_mask.cumsum(dim=-1) - batch["attention_mask"] = attn_mask - return batch - - -def benchmark_module(module, runs=100): - # Set up inputs - input_data = dummy_batch() - input_data = transfer_batch_to_device(input_data, "cuda") - - # Warm-up to ensure accurate measurement - for _ in range(10): - loss = module(**input_data).loss - loss.backward() - module.zero_grad() - - forward_time_total = 0.0 - backward_time_total = 0.0 - - # Benchmark runs - for _ in range(runs): - # Forward pass timing - torch.cuda.synchronize() - start_time = time.time() - loss = module(**input_data).loss - torch.cuda.synchronize() - forward_time = time.time() - start_time - - # Backward pass timing - torch.cuda.synchronize() - start_time = time.time() - loss.backward() - torch.cuda.synchronize() - backward_time = time.time() - start_time - - # Zero gradients - module.zero_grad() - - # Accumulate times - forward_time_total += forward_time - backward_time_total += backward_time - - avg_forward_time = forward_time_total / runs - avg_backward_time = backward_time_total / runs - avg_runtime = avg_forward_time + avg_backward_time - - # Measure memory usage - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - loss = module(**input_data).loss # Forward pass to record memory - loss.backward() # Backward pass to record memory - memory_allocated = torch.cuda.max_memory_allocated() - memory_reserved = torch.cuda.max_memory_reserved() - - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - return ( - avg_runtime, - avg_forward_time, - avg_backward_time, - memory_allocated, - memory_reserved, - ) - - -def run_benchmark(name, adapter_config): - seed_everything(0) - model = model_loader_helper( - model_name, - bf16=True, - fp16=False, - load_in_4bit=False, - load_in_8bit=False, - device_map="cpu", - ) - modify_transformer(model, adapter_config) - model.to("cuda") - n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - - runtime, forward_time, backward_time, sparse_alloc, sparse_reserved = ( - benchmark_module(model, runs=n_iters) - ) - print( - f"{name} - Runtime: {runtime:.6f}s, Allocated Memory: {sparse_alloc / 1e6:.2f}MB, Reserved Memory: {sparse_reserved / 1e6:.2f}MB" - ) - table.loc[name] = [ - runtime, - forward_time, - backward_time, - sparse_alloc, - sparse_reserved, - n_params, - ] - - -############################################################################################################################################################ -# Benchmarking LoRA - -adapter_config = LoRAConfig(modify_layers=modify_layers, lora_rank=lora_rank) -run_benchmark("LoRA", adapter_config) - -################################################################################################################################################################# -# Benchmarking BlcockSparseLinearModule - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="triton_block_sparse", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) 
-run_benchmark("BlockSparseLinearModule", adapter_config) - - -################################################################################################################################################################# -# Benchmarking SparseLinearModule - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="sp_add+sp_mm", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - # mask_updater=mask_updater, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("SparseLinearModule (reg sp.)", adapter_config) - -############################################################################################################################################################ -# Benchmarking SparseLinearModule with block sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="sp_add+sp_mm", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("SparseLinearModule (block sp.)", adapter_config) - -################################################################################################################################################################# -# Benchmarking SPiEL with regular sparsity kernel - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="spiel", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("Spiel Linear (reg. sp)", adapter_config) - - -################################################################################################################################################################# -# Benchmarking MaskedLinear with regular sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="masked_linear", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("MaskedLinear (reg. 
sp)", adapter_config) - -############################################################################################################################################################ -# Benchmarking MaskedLinear with block sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="masked_linear", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("MaskedLinear (block sp.)", adapter_config) - -################################################################################################################################################################# -# Benchmarking ScatteredSparseLinearModule - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="scattered", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("ScatteredSparseLinearModule (block sp.)", adapter_config) - -############################################################################################################################################################ -# Benchmarking ScatteredSparseLinearModule with regular sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="scattered", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("ScatteredSparseLinearModule (reg sp.)", adapter_config) - -############################################################################################################################################################ -print(table) -# write table to a csv file -table.to_csv("benchmark_results.csv") diff --git a/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask_only_linear.py b/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask_only_linear.py deleted file mode 100644 index 619e733b8..000000000 --- a/mttl/models/modifiers/spasity/sparse_utils/profile_sparse_mask_only_linear.py +++ /dev/null @@ -1,352 +0,0 @@ -import logging -import re -import time - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -from pytorch_lightning import seed_everything - -from mttl.logging import logger -from mttl.models.modifiers import modify_transformer -from mttl.models.modifiers.lora import LoRA, LoRAConfig -from mttl.models.modifiers.spasity.sparse_mask import ( - MaskedLinear, - ScatteredSparseLinearModule, - SparseLinearModule, - SparseMaskAdapter, - SparseMaskConfig, -) -from mttl.models.utils import model_loader_helper, transfer_batch_to_device - -logger.setLevel(logging.ERROR) -model_name = "EleutherAI/gpt-neo-125m" # "EleutherAI/gpt-neo-125m" # "phi-2" -block_size = 64 -n_blocks = 128 -mask_updater = None -modify_layers = ".*q_proj.*|.*v_proj.*|.*k_proj.*" # ".*q_proj.*|.*v_proj.*|.*k_proj.*" # ".*Wqkv.*" # -n_iters = 50 - -in_d = 2048 -out_d = 8192 * 2 -dtype = torch.bfloat16 - -# input sizes and batch sizes for testing -max_seq_len = 1024 -bs = 5 -vocab_size = 32000 - - -def calculate_lora_parameters(input_dim, output_dim, rank): - return input_dim * rank + output_dim * rank - - -layer = nn.Linear(in_d, out_d) -layer.weight.requires_grad_(False) -layer.bias.requires_grad_(False) - - -def find_hyperpaams(): - modules = {"linear": layer} - modified_modules = {} - keep_ratios = [] - lora_ranks = [] - - for name, module in modules.items(): - keep_ratio = ( - n_blocks * (block_size**2) / (module.in_features * module.out_features) - ) - tot_sparse_params = 
module.in_features * module.out_features * keep_ratio - lora_rank = 1 - for rank in range(1, module.in_features): - lora_params = calculate_lora_parameters( - module.in_features, module.out_features, rank - ) - if lora_params <= tot_sparse_params: - lora_rank = rank - else: - break - modified_modules[name] = { - "module": module, - "keep_ratio": keep_ratio, - "lora_rank": lora_rank, - } - keep_ratios.append(keep_ratio) - lora_ranks.append(lora_rank) - return np.mean(keep_ratios), int(np.mean(lora_ranks)) - - -keep_ratio, lora_rank = find_hyperpaams() -print(f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}") - - -table = pd.DataFrame( - columns=[ - "Av. Runtime", - "Av. Forward time", - "Av. Backward time", - "Allocated Memory", - "Reserved Memory", - "Number of Parameters", - ] -) - - -def dummy_batch(): - torch.manual_seed(0) - batch = { - "input_ids": torch.randint(10, vocab_size, (bs, max_seq_len)), - "labels": torch.randint(10, vocab_size, (bs, max_seq_len)), - } - seq_len = torch.randint(0, max_seq_len, (bs,)) - attn_mask = torch.zeros(bs, max_seq_len, dtype=torch.int32) - attn_mask[torch.arange(bs), seq_len] = 1 - attn_mask = 1 - attn_mask.cumsum(dim=-1) - batch["attention_mask"] = attn_mask - return batch - - -def benchmark_module(module, runs=100): - # Set up inputs - input_data = dummy_batch() - input_data = torch.rand(bs, max_seq_len, in_d).to("cuda").to(dtype) - - # Warm-up to ensure accurate measurement - for _ in range(10): - out = module(input_data) - loss = torch.mean(out) - loss.backward() - module.zero_grad() - - forward_time_total = 0.0 - backward_time_total = 0.0 - - # Benchmark runs - for _ in range(runs): - # Forward pass timing - torch.cuda.synchronize() - start_time = time.time() - out = module(input_data) - loss = torch.mean(out) - torch.cuda.synchronize() - forward_time = time.time() - start_time - - # Backward pass timing - torch.cuda.synchronize() - start_time = time.time() - loss.backward() - torch.cuda.synchronize() - backward_time = time.time() - start_time - - # Zero gradients - module.zero_grad() - - # Accumulate times - forward_time_total += forward_time - backward_time_total += backward_time - - avg_forward_time = forward_time_total / runs - avg_backward_time = backward_time_total / runs - avg_runtime = avg_forward_time + avg_backward_time - - # Measure memory usage - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - out = module(input_data) # Forward pass to record memory - loss = torch.mean(out) - loss.backward() # Backward pass to record memory - memory_allocated = torch.cuda.max_memory_allocated() - memory_reserved = torch.cuda.max_memory_reserved() - - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - return ( - avg_runtime, - avg_forward_time, - avg_backward_time, - memory_allocated, - memory_reserved, - ) - - -def run_benchmark(name, adapter_config): - seed_everything(0) - if isinstance(adapter_config, LoRAConfig): - module = LoRA(adapter_config, layer) - else: - module = SparseMaskAdapter(adapter_config, layer) - - module.to("cuda").to(dtype) - n_params = sum(p.numel() for p in module.parameters() if p.requires_grad) - - runtime, forward_time, backward_time, sparse_alloc, sparse_reserved = ( - benchmark_module(module, runs=n_iters) - ) - print( - f"{name} - Runtime: {runtime:.6f}s, Allocated Memory: {sparse_alloc / 1e6:.2f}MB, Reserved Memory: {sparse_reserved / 1e6:.2f}MB, Number of Parameters: {n_params}" - ) - table.loc[name] = [ - runtime, - forward_time, - backward_time, - 
sparse_alloc, - sparse_reserved, - n_params, - ] - - -############################################################################################################################################################ -# Benchmarking LoRA - -adapter_config = LoRAConfig(modify_layers=modify_layers, lora_rank=lora_rank) -run_benchmark("LoRA", adapter_config) - -################################################################################################################################################################# -# Benchmarking BlcockSparseLinearModule + Dense without spoops - -# adapter_config = SparseMaskConfig( -# modify_layers=modify_layers, -# sps_impl="dense+triton_block_sparse", -# sps_type="block_sparse", -# keep_ratio=keep_ratio, -# reselection_steps=1, -# block_size=block_size, -# ) -# run_benchmark("BlockSparseLinearModule + Dense", adapter_config) - -################################################################################################################################################################# -# Benchmarking BlcockSparseLinearModule without spoops - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="triton_block_sparse_scatter", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("BlockSparseLinearModule (scatter add)", adapter_config) - -################################################################################################################################################################# -# Benchmarking BlcockSparseLinearModule - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="triton_block_sparse", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("BlockSparseLinearModule", adapter_config) - - -################################################################################################################################################################# -# Benchmarking SparseLinearModule - - -# adapter_config = SparseMaskConfig( -# modify_layers=modify_layers, -# sps_impl="sp_add+sp_mm", -# sps_type="regular_sparse", -# keep_ratio=keep_ratio, -# # mask_updater=mask_updater, -# reselection_steps=1, -# block_size=block_size, -# ) -# run_benchmark("SparseLinearModule (reg sp.)", adapter_config) - -# ############################################################################################################################################################ -# # Benchmarking SparseLinearModule with block sparsity - - -# adapter_config = SparseMaskConfig( -# modify_layers=modify_layers, -# sps_impl="sp_add+sp_mm", -# sps_type="block_sparse", -# keep_ratio=keep_ratio, -# reselection_steps=1, -# block_size=block_size, -# ) -# run_benchmark("SparseLinearModule (block sp.)", adapter_config) - -################################################################################################################################################################# -# Benchmarking SPiEL with regular sparsity kernel - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="spiel", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("Spiel Linear (reg. 
sp)", adapter_config) - - -################################################################################################################################################################# -# Benchmarking MaskedLinear with regular sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="masked_linear", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("MaskedLinear (reg. sp)", adapter_config) - -############################################################################################################################################################ -# Benchmarking MaskedLinear with block sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="masked_linear", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("MaskedLinear (block sp.)", adapter_config) - -################################################################################################################################################################# -# Benchmarking ScatteredSparseLinearModule - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="scattered", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("ScatteredSparseLinearModule (block sp.)", adapter_config) - -############################################################################################################################################################ -# Benchmarking ScatteredSparseLinearModule with regular sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="scattered", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("ScatteredSparseLinearModule (reg sp.)", adapter_config) - -############################################################################################################################################################ -# orer table by Av. Runtime -table = table.sort_values("Av. Runtime") -print(table) -# write table to a csv file -table.to_csv("benchmark_results.csv") diff --git a/mttl/models/modifiers/spasity/sparse_utils/utils.py b/mttl/models/modifiers/spasity/sparse_utils/utils.py index 61a831aef..52756766c 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/utils.py +++ b/mttl/models/modifiers/spasity/sparse_utils/utils.py @@ -529,3 +529,141 @@ def backward(ctx, output_grad): return tuple(grads) else: return (grads[0], None) + tuple(grads[2:]) + + +import torch + + +def padded_gather(x, indices, E, block_size=16): + """ + Permute tokens to group them by expert. + Ensures that the number of tokens per expert is divisible by block_size by adding padding. + Returns additional data for padded_scatter. 
+ """ + batch_size, seq_len, d_model = x.size() + top_k = indices.size(-1) + + # Step 1: Flatten x and indices + x_flat = x.view(-1, d_model) # [batch_size * seq_len, d_model] + indices_flat = indices.view(-1) # [batch_size * seq_len * top_k] + + # Step 2: Expand x to match indices + x_flat_expanded = x_flat.unsqueeze(1).expand(-1, top_k, -1) # [batch_size * seq_len, top_k, d_model] + x_expert = x_flat_expanded.reshape(-1, d_model) # [batch_size * seq_len * top_k, d_model] + + # Step 3: Sort indices and x_expert to group tokens by expert + indices_expert, sort_order = indices_flat.sort() + x_expert_sorted = x_expert[sort_order] + + # Step 4: Compute number of tokens per expert + num_tokens_per_expert = torch.bincount(indices_expert, minlength=E) # [E] + + # Step 5: Compute padded number of tokens per expert + padded_num_tokens_per_expert = ((num_tokens_per_expert + block_size - 1) // block_size) * block_size # [E] + max_tokens_per_expert = padded_num_tokens_per_expert.max().item() + + # Step 6: Compute positions within each expert + def compute_positions_in_group(indices_expert): + unique_indices, counts = indices_expert.unique_consecutive(return_counts=True) + positions_in_expert = torch.cat([torch.arange(count, device=indices_expert.device) for count in counts]) + return positions_in_expert + + positions_in_expert = compute_positions_in_group(indices_expert) + + # Step 7: Pad the tokens per expert to make counts divisible by block_size + # For each expert, determine padding needed + padding_needed = padded_num_tokens_per_expert - num_tokens_per_expert # [E] + + indices_expert_padded = [] + positions_in_expert_padded = [] + x_expert_padded = [] + padding_mask = [] + + current_idx = 0 + for e in range(E): + count = num_tokens_per_expert[e].item() + padded_count = padded_num_tokens_per_expert[e].item() + padding = padding_needed[e].item() + + # Get the indices and positions for the current expert + indices_e = indices_expert[current_idx:current_idx+count] + positions_e = positions_in_expert[current_idx:current_idx+count] + x_expert_e = x_expert_sorted[current_idx:current_idx+count] + + # Append original tokens + indices_expert_padded.append(indices_e) + positions_in_expert_padded.append(positions_e) + x_expert_padded.append(x_expert_e) + padding_mask.append(torch.ones(count, dtype=torch.bool, device=indices_expert.device)) + + # If padding is needed, duplicate the last token 'padding' times + if padding > 0: + indices_e_pad = indices_e.new_full((padding,), fill_value=e) + positions_e_pad = positions_e.new_tensor(range(count, padded_count)) + # For x_expert, duplicate the last token + x_expert_e_pad = x_expert_e[-1:].expand(padding, -1) # Duplicate last token + + indices_expert_padded.append(indices_e_pad) + positions_in_expert_padded.append(positions_e_pad) + x_expert_padded.append(x_expert_e_pad) + padding_mask.append(torch.zeros(padding, dtype=torch.bool, device=indices_expert.device)) + + current_idx += count + + # Concatenate all the padded indices, positions, tokens, and mask + indices_expert_padded = torch.cat(indices_expert_padded, dim=0) + positions_in_expert_padded = torch.cat(positions_in_expert_padded, dim=0) + x_expert_padded = torch.cat(x_expert_padded, dim=0) + padding_mask = torch.cat(padding_mask, dim=0) # [total_padded_tokens] + + # Step 8: Initialize output tensor + output = x.new_zeros(E, max_tokens_per_expert, d_model) + + # Step 9: Assign tokens to output tensor + output[indices_expert_padded, positions_in_expert_padded] = x_expert_padded + + # Return additional 
information for padded_scatter + return output, num_tokens_per_expert, sort_order, indices_expert_padded, positions_in_expert_padded, padding_mask + + +def padded_scatter(x, num_tokens_per_expert, sort_order, batch_size, seq_len, top_k, d_model): + """ + Un-permute tokens back to their original positions. + + Args: + x (torch.Tensor): Input tensor of shape [E, max_tokens_per_expert, d_model], outputs from experts. + num_tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert, shape [E]. + sort_order (torch.Tensor): The sort order used in padded_gather. + batch_size (int): Original batch size. + seq_len (int): Original sequence length. + top_k (int): Number of experts per token. + d_model (int): Model dimension. + + Returns: + output (torch.Tensor): Un-permuted tokens, shape [batch_size, seq_len, top_k, d_model]. + """ + E, max_tokens_per_expert, _ = x.size() + device = x.device + + # Step 1: Flatten x and remove padding + x_flat = x.view(-1, d_model) # [E * max_tokens_per_expert, d_model] + + # Step 2: Build indices for valid tokens + expert_indices = torch.repeat_interleave(torch.arange(E, device=device), num_tokens_per_expert) + positions_in_expert = torch.cat([torch.arange(n, device=device) for n in num_tokens_per_expert]) + valid_positions = expert_indices * max_tokens_per_expert + positions_in_expert + + # Step 3: Select valid tokens + x_valid = x_flat[valid_positions] + + # Step 4: Reconstruct x_expert_sorted + x_expert_sorted = x_valid + + # Step 5: Reconstruct x_expert using inverse of sort_order + x_expert = torch.empty((batch_size * seq_len * top_k, d_model), device=device, dtype=x.dtype) + x_expert[sort_order] = x_expert_sorted + + # Step 6: Reshape to [batch_size, seq_len, top_k, d_model] + x_unpermuted = x_expert.view(batch_size, seq_len, top_k, d_model) + + return x_unpermuted diff --git a/mttl/models/modifiers/spasity/stk/functions.py b/mttl/models/modifiers/spasity/stk/functions.py index 77271d187..1af296089 100644 --- a/mttl/models/modifiers/spasity/stk/functions.py +++ b/mttl/models/modifiers/spasity/stk/functions.py @@ -1,8 +1,14 @@ +from typing import Any + import torch from stk.backend.autocast import custom_bwd, custom_fwd from stk.matrix import Matrix import mttl.models.modifiers.spasity.stk.triton_kernels as backend +from mttl.models.modifiers.spasity.stk.scatter_moe_kernels import ( + scatter2scatter_sparse, + scatter2scatter_sparse_optimized, +) class RowIndices(torch.autograd.Function): @@ -52,3 +58,140 @@ def backward(ctx, dy): sdd_spsmerge = SDD_SpMerge.apply + + +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return backend.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = backend.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply + + +class ParalleLinear(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx, + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + 
ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates + ): + + output = scatter2scatter_sparse( + X=x, + base_act=base_act, + ada_weights=ada_weights, + row_idxs=row_idxs, + col_idxs_t=col_idxs, + offsets_t=offsets, + block_offsets_t=block_offsets_t, + ada_block=ada_block_size, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + k=k, + gates=gates, + ) + output = output.view(gates.size(0), gates.size(1), output.size(-1)).sum( + 1 + ) # this can be moved into kernel? + return output + + +parallel_linear = ParalleLinear.apply + + + +class ParalleLinear2(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx, + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates=None, + ): + + output = scatter2scatter_sparse_optimized( + X=x, + base_act=base_act, + ada_weights=ada_weights, + row_idxs=row_idxs, + col_idxs_t=col_idxs, + offsets_t=offsets, + block_offsets_t=block_offsets_t, + ada_block=ada_block_size, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + k=k, + gates=gates, + ) + output = output.view(gates.size(0), gates.size(1), output.size(-1)).sum( + 1 + ) # this can be moved into kernel? + return output + + +parallel_linear_optimized = ParalleLinear2.apply \ No newline at end of file diff --git a/mttl/models/modifiers/spasity/stk/linear_ops.py b/mttl/models/modifiers/spasity/stk/linear_ops.py index 963908f55..5a6bbbd69 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops.py +++ b/mttl/models/modifiers/spasity/stk/linear_ops.py @@ -43,3 +43,133 @@ def sdd_adamerge(a, b, out_topo: Matrix, out_adaps: Matrix, layout): out_topo.offsets_t, out_topo.block_offsets_t, ) + + +def scattergather_adamerge( + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates, +): + + out = functions.parallel_linear( + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates, + ) + return out + +def scattergather_adamerge2( + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates, +): + + out = functions.parallel_linear_optimized( + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates, + ) + return out + + +BLOCK_M = 128 # expert token capacity + + +@torch.jit.script +def flatten_and_sort(expert_idxs: torch.Tensor): + """ + Flattens a tensor of expert indices and sorts the flattened tensor. + + Args: + expert_idxs (torch.Tensor): A tensor containing expert indices. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - sorted_expert_idxs: The flattened and sorted expert indices. + - sorted_scattered_idxs: The indices that would sort the flattened tensor. 
+ """ + flattened_expert_idxs = expert_idxs.flatten() + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) + return sorted_expert_idxs, sorted_scattered_idxs + + +@torch.jit.script +def padded_block_indices( + sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M +): + """ + Compute padded block indices for sorted experts. + + This function calculates the indices of padded blocks for a given set of sorted expert indices. + It ensures that the blocks are padded to a specified block size. + + Args: + sorted_experts_idxs (torch.Tensor): A tensor containing the sorted indices of experts. + k (int): The number of unique experts. + N_BLOCK_SIZE (int, optional): The size of each block. Defaults to BLOCK_M. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - expanded_block_idxs (torch.Tensor): The indices of the expanded blocks. + - expert_boundaries_end (torch.Tensor): The end boundaries of the experts. + """ + expert_counts = torch.bincount(sorted_experts_idxs, minlength=k) + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 + padded_expert_block_end = padded_block_counts.cumsum(-1) + expert_boundaries_end = expert_counts.cumsum(-1) + expert_boundaries_start = expert_boundaries_end - expert_counts + padded_expert_block_start = padded_expert_block_end - padded_block_counts + block_idxs = torch.arange( + padded_expert_block_end[-1], + dtype=sorted_experts_idxs.dtype, + device=sorted_experts_idxs.device, + ) + block_mask = (block_idxs[:, None] < padded_expert_block_start) | ( + block_idxs[:, None] >= padded_expert_block_end + ) + expanded_block_idxs = ( + N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) + + expert_boundaries_start + ) + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) + return expanded_block_idxs, expert_boundaries_end diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test.py b/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py similarity index 54% rename from mttl/models/modifiers/spasity/stk/linear_ops_test.py rename to mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py index dd543f096..15ba7e88b 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops_test.py +++ b/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py @@ -5,10 +5,18 @@ import numpy as np import stk import torch +import torch.nn.functional as F from absl.testing import parameterized +from pytorch_lightning import seed_everything from stk.matrix import Matrix -from mttl.models.modifiers.spasity.stk import linear_ops, matrix_ops +from mttl.models.modifiers.spasity.stk import functions, linear_ops, matrix_ops + +# os.environ["TRITON_INTERPRET"] = "1" + + +# os.environ["TRITON_INTERPRET"] = "1" + # os.environ["TRITON_INTERPRET"] = "1" @@ -143,5 +151,100 @@ def testLinearOps_Sdd_wAdapters( self.assertTrue(allclose(out, expected_out)) +SC_MOE_TEST = { + (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), + (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), + (8, 128, 256, 10, 2, 0.8, 16, torch.float32), +} + +def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): + output = torch.stack( + [ + sum( + base_act[i] + + (expert_p[i, j] * torch.matmul(x[i], (adaps[expert_idxs[i, j]]))) + for j in range(expert_idxs.size(1)) + ) + for i in range(expert_idxs.size(0)) + ], + dim=0, + ) + return output + + +@parameterized.parameters(*SC_MOE_TEST) +class ScatteredMoETest(parameterized.TestCase): + def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, 
dtype): + torch.manual_seed(42) + # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") + logits = torch.randn(bs, E, dtype=dtype) + weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) + X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() + W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() + adaps = [_dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] + adaps_sparse = [adap[1] for adap in adaps] + adaps_dense = [adap[0] for adap in adaps] + ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) + row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) + col_idxs_t = torch.stack( + [adap.column_indices_t for adap in adaps_sparse], dim=0 + ) + offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) + block_offsets_t = torch.stack( + [adap.block_offsets_t for adap in adaps_sparse], dim=0 + ) + + k_weights, expert_idxs = torch.topk(weights, k) + sorted_expert_idxs, sorted_scattered_idxs = linear_ops.flatten_and_sort( + expert_idxs + ) + padded_block_idxs, expert_offsets = linear_ops.padded_block_indices( + sorted_expert_idxs, E + ) + + base_act = torch.matmul(X, W) + + out = functions.parallel_linear( + x=X, + base_act=base_act, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + col_idxs=col_idxs_t, + offsets=offsets_t, + block_offsets_t=block_offsets_t, + ada_block_size=blocking, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + gates=k_weights, + ) + + + out2 = functions.parallel_linear_optimized( + x=X, + base_act=base_act, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + col_idxs=col_idxs_t, + offsets=offsets_t, + block_offsets_t=block_offsets_t, + ada_block_size=blocking, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + gates=k_weights, + ) + + + + out_dumb = dumb_forward(base_act, X, k_weights, expert_idxs, adaps_dense) + err_Y = torch.abs(out - out_dumb) + tolerance = 1e-2 + # print(err_Y.max()) + assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() + + if __name__ == "__main__": unittest.main() diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py b/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py new file mode 100644 index 000000000..f1cf80bb7 --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py @@ -0,0 +1,143 @@ +import itertools +import os +import unittest + +import numpy as np +import stk +import torch +import torch.nn.functional as F +from absl.testing import parameterized +from pytorch_lightning import seed_everything +from stk.matrix import Matrix + +from mttl.models.modifiers.spasity.stk import functions, linear_ops, matrix_ops + +# os.environ["TRITON_INTERPRET"] = "1" + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +blocksize = 16 + + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return ( + dense.to(cuda_device).requires_grad_(True), 
+ sparse.to(cuda_device).requires_grad_(True), + ) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + + +SC_MOE_TEST = { + # (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), + # (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), + (8, 128, 256, 10, 2, 0.8, 16, torch.float32), +} + +def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): + output = torch.stack( + [ + sum( + base_act[i] + + (expert_p[i, j] * torch.matmul(x[i], (adaps[expert_idxs[i, j]]))) + for j in range(expert_idxs.size(1)) + ) + for i in range(expert_idxs.size(0)) + ], + dim=0, + ) + return output + + +@parameterized.parameters(*SC_MOE_TEST) +class ScatteredMoETest(parameterized.TestCase): + def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): + torch.manual_seed(42) + # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") + logits = torch.randn(bs, E, dtype=dtype) + weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) + X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() + W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() + adaps = [_dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] + adaps_sparse = [adap[1] for adap in adaps] + adaps_dense = [adap[0] for adap in adaps] + ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) + row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) + col_idxs_t = torch.stack( + [adap.column_indices_t for adap in adaps_sparse], dim=0 + ) + offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) + block_offsets_t = torch.stack( + [adap.block_offsets_t for adap in adaps_sparse], dim=0 + ) + + k_weights, expert_idxs = torch.topk(weights, k) + sorted_expert_idxs, sorted_scattered_idxs = linear_ops.flatten_and_sort( + expert_idxs + ) + padded_block_idxs, expert_offsets = linear_ops.padded_block_indices( + sorted_expert_idxs, E + ) + + base_act = torch.matmul(X, W) + + # out = linear_ops.scattergather_adamerge( + # x=X, + # base_act=base_act, + # k=k, + # ada_weights=ada_data, + # row_idxs=row_idxs, + # col_idxs=col_idxs_t, + # offsets=offsets_t, + # block_offsets_t=block_offsets_t, + # ada_block_size=blocking, + # sorted_expert_idxs=sorted_expert_idxs, + # sorted_scattered_idxs=sorted_scattered_idxs, + # padded_block_idxs=padded_block_idxs, + # gates=k_weights, + # ) + + + out2 = linear_ops.scattergather_adamerge2( + x=X, + base_act=base_act, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + col_idxs=col_idxs_t, + offsets=offsets_t, + block_offsets_t=block_offsets_t, + ada_block_size=blocking, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + gates=k_weights, + ) + + + + out_dumb = dumb_forward(base_act, X, k_weights, expert_idxs, adaps_dense) + err_Y = torch.abs(out2 - out_dumb) + tolerance = 1e-2 + print(err_Y.max()) + assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() + + +if __name__ == "__main__": + unittest.main() diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops.py b/mttl/models/modifiers/spasity/stk/matrix_ops.py index 6276ecc4d..0163550a6 100644 --- a/mttl/models/modifiers/spasity/stk/matrix_ops.py +++ b/mttl/models/modifiers/spasity/stk/matrix_ops.py @@ -1,16 +1,19 @@ from typing import List import numpy as np +import stk.ops import torch 
from stk.matrix import Matrix -def merge_adapters(adapters: List[Matrix]) -> Matrix: +def _merge_adapters (adapters: List[Matrix]) -> Matrix: """ Merges a list of adapters into a single adapter along the second dimention. + Also changes the block size by padding blocks iwht 0s if necessary. + """ - col_indices_list = [adap.column_indices for adap in adapters] - row_indices_list = [adap.row_indices for adap in adapters] + col_indices_list = [adap.column_indices.to(torch.int32) for adap in adapters] + # row_indices_list = [adap.row_indices for adap in adapters] offsets_list = [adap.offsets for adap in adapters] data_list = [adap.data for adap in adapters] @@ -20,6 +23,7 @@ def merge_adapters(adapters: List[Matrix]) -> Matrix: ), "All adapters must have the same number of rows" block_size = adapters[0].blocking + K, N = adapters[0].size() col_offset = N // block_size # assuming all have same number of cols n_adaps = len(adapters) @@ -29,7 +33,9 @@ def merge_adapters(adapters: List[Matrix]) -> Matrix: adjusted_col_indices.append(col_idx + e * col_offset) merged_col_indices = torch.cat(adjusted_col_indices) - row_indices = torch.cat([adap.row_indices for adap in adapters], dim=0) + row_indices = torch.cat( + [adap.row_indices.to(torch.int32) for adap in adapters], dim=0 + ) data = torch.cat([adap.data for adap in adapters], dim=0) indices = torch.stack([row_indices, merged_col_indices], dim=1) @@ -55,6 +61,28 @@ def merge_adapters(adapters: List[Matrix]) -> Matrix: return Matrix((K, n_adaps * N), data, row_indices, col_indices, offsets) +def change_block_size(M: Matrix, new_blk_size) -> Matrix: + raise NotImplementedError("change_block_size is not implemented yet") + return + + + + + + +def merge_adapters(adapters: List[Matrix], blk_size = None) -> Matrix: + """ + Merges a list of adapters into a single adapter along the second dimention. + Also changes the block size by padding blocks iwht 0s if necessary. 
+ + """ + + out = _merge_adapters(adapters) # merges the adapters into a single Matrix() without changing the block size + if blk_size is not None: + out = change_block_size(out, blk_size) + return out + + def create_ada_layout(matix: Matrix): """ diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops_test.py b/mttl/models/modifiers/spasity/stk/matrix_ops_test.py new file mode 100644 index 000000000..fd4da23db --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/matrix_ops_test.py @@ -0,0 +1,39 @@ +import unittest + +import stk +import stk.ops +import torch +from absl.testing import parameterized +from stk.matrix import Matrix + +from mttl.models.modifiers.spasity.stk import matrix_ops + + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return ( + dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True), + ) + +@parameterized.parameters( + (2, 8, 16, 0.5, 1), + (2, 8, 16, 0.5, 4) + ) +class MatrixOpsTest(parameterized.TestCase): + def test_layout_creation(self, K, rows, cols, sparsity, blocking): + adaps = [_dense_and_sparse(rows, cols, sparsity, blocking, torch.float16) for _ in range(K)] + adaps_sparse = [adap[1] for adap in adaps] + # adaps_dense = [adap[0] for adap in adaps] + + merged_adaps_matrix: Matrix = matrix_ops.merge_adapters(adaps_sparse) + layout = matrix_ops.create_ada_layout(merged_adaps_matrix) + assert layout.max() == merged_adaps_matrix.data.size(0) - 1 + + + +if __name__ == '__main__': + unittest.main() diff --git a/mttl/models/modifiers/spasity/stk/measure_time.py b/mttl/models/modifiers/spasity/stk/measure_time.py new file mode 100644 index 000000000..4b2ead74a --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/measure_time.py @@ -0,0 +1,180 @@ +import time +from functools import partial + +import numpy as np +import stk +import torch +import torch.nn.functional as F +from absl.testing import parameterized +from pytorch_lightning import seed_everything +from stk.matrix import Matrix + +from mttl.models.modifiers.spasity.stk import functions, linear_ops, matrix_ops +from mttl.models.modifiers.spasity.stk.linear_ops_test_scatter import ( + _dense_and_sparse, + dumb_forward, +) + + +def benchmark_module(name, function, runs=100): + # Warm-up to ensure accurate measurement + for _ in range(10): + out = function() + + forward_time_total = 0.0 + + # Benchmark runs + for _ in range(runs): + # Forward pass timing + torch.cuda.synchronize() + start_time = time.time() + out = function() + torch.cuda.synchronize() + forward_time = time.time() - start_time + + # Accumulate times + forward_time_total += forward_time + + avg_forward_time = forward_time_total / runs + + # Measure memory usage + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + out = function() # Forward pass to record memory + memory_allocated = torch.cuda.max_memory_allocated() + memory_reserved = torch.cuda.max_memory_reserved() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + print( + f"Average forward time {name}: {avg_forward_time:.6f}s", + f"Memory allocated: {memory_allocated/1024**2:.2f}MB", + f"Memory reserved: {memory_reserved/1024**2:.2f}MB", + ) + + +def calculate_lora_parameters(input_dim, output_dim, rank): + return input_dim * rank + output_dim * rank + +def 
find_lora_hyperpaams(d_in, d_out, tot_sparse_params): + lora_ranks = [] + lora_rank = 1 + for rank in range(1, d_in): + lora_params = calculate_lora_parameters(d_in, d_out, rank) + if lora_params <= tot_sparse_params: + lora_rank = rank + else: + break + lora_ranks.append(lora_rank) + return int(np.mean(lora_ranks)) + +SC_MOE_TEST = { + # bs, d, h, E, k, sparsity, blocking, dtype + (1024, 2048, 8192, 20, 2, 0.995, 16, torch.float16), + (1024, 2048, 8192, 20, 2, 0.9, 128, torch.float16), + (1024, 2048, 8192, 100, 2, 0.995, 16, torch.float16), + (1024, 2048, 8192, 100, 2, 0.9, 128, torch.float16), + # (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float16), + # (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float16), + # (8, 128, 256, 10, 2, 0.8, 16, torch.float16), +} + + +for bs, d, h, E, k, sparsity, blocking, dtype in SC_MOE_TEST: + print("=====================================================================") + print( + f"***** Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype} *****" + ) + + torch.manual_seed(42) + # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") + logits = torch.randn(bs, E, dtype=dtype) + weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) + X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() + W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() + adaps = [_dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] + adaps_sparse = [adap[1] for adap in adaps] + adaps_dense = [adap[0] for adap in adaps] + ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) + row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) + col_idxs_t = torch.stack([adap.column_indices_t for adap in adaps_sparse], dim=0) + offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) + block_offsets_t = torch.stack( + [adap.block_offsets_t for adap in adaps_sparse], dim=0 + ) + + k_weights, expert_idxs = torch.topk(weights, k) + + def call_with_baseact_and_idxs_computation(X, W, expert_idxs, function, **kwargs): + base_act = torch.matmul(X, W) + sorted_expert_idxs, sorted_scattered_idxs = linear_ops.flatten_and_sort( + expert_idxs + ) + padded_block_idxs, expert_offsets = linear_ops.padded_block_indices( + sorted_expert_idxs, E + ) + return function( + x=X, + base_act=base_act, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + **kwargs, + ) + + # base_act = torch.matmul(X, W) + func = partial( + call_with_baseact_and_idxs_computation, + X=X, + W=W, + expert_idxs=expert_idxs, + function=linear_ops.scattergather_adamerge, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + col_idxs=col_idxs_t, + offsets=offsets_t, + block_offsets_t=block_offsets_t, + ada_block_size=blocking, + gates=k_weights, + ) + benchmark_module("BS kernel not optimized", func) + # func_dummb = partial(dumb_forward, base_act=base_act, x=X, expert_p=k_weights, expert_idxs=expert_idxs, adaps=adaps_dense) + # benchmark_module("dummy forward", func_dummb) + + func_opt = partial( + call_with_baseact_and_idxs_computation, + X=X, + W=W, + expert_idxs=expert_idxs, + function=linear_ops.scattergather_adamerge2, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + col_idxs=col_idxs_t, + offsets=offsets_t, + block_offsets_t=block_offsets_t, + ada_block_size=blocking, + gates=k_weights, + ) + benchmark_module("BS kernel optimized", func) + 
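+    # The LoRA baseline below is sized to roughly match the per-expert parameter budget of
+    # the block-sparse adapters: a rank-r LoRA on a (d, h) layer uses (d + h) * r parameters,
+    # while each sparse adapter stores about (1 - sparsity) * d * h weights. For example,
+    # d=2048, h=8192, sparsity=0.995 gives ~84K parameters per expert, i.e. a rank of about 8.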
lora_rank = find_lora_hyperpaams(d, h, np.prod(ada_data.shape[1:])) + + + def lora_merge(lora_a, lora_b, x, W_base, W_merge): + # LoRA does not profit from lower top-k in this vanila form + # merge into 1 lora + A = torch.einsum("be,edr->bdr", (W_merge, lora_a)) + B = torch.einsum("be,erd->brd", (W_merge, lora_b)) + # lora forward + partial_out = torch.einsum("bd,bdr->br", (x, A)) + adapter_out = torch.einsum("br,brd->bd", (partial_out, B)) + dense_out = x @ W_base + return adapter_out + dense_out + + lora_a = torch.randn(E, d, lora_rank, dtype=dtype).cuda().contiguous() + lora_b = torch.randn(E, lora_rank, h, dtype=dtype).cuda().contiguous() + func_lora = partial(lora_merge, lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights) + # benchmark_module("LoRA merge (our current vanila)", func_lora) \ No newline at end of file diff --git a/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py b/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py new file mode 100644 index 000000000..67dbf2adc --- /dev/null +++ b/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py @@ -0,0 +1,622 @@ +import torch +import triton +import triton.language as tl +from torch.nn import functional as F + +BLOCK_M = 128 + + +def _scatter2scatter_configs(): + return [ + triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=1, num_warps=1), + ] + + +@triton.autotune( + configs=_scatter2scatter_configs(), + key=["M", "N", "K"], +) +@triton.heuristics( + { + "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0, + "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0, + "ADA_BLCKS_PER_TILE_K": lambda args: args["BLOCK_K"] // args["ADA_BLOCK"], + "ADA_BLCKS_PER_TILE_N": lambda args: args["BLOCK_N"] // args["ADA_BLOCK"], + } +) +@triton.jit +def _scatter2scatter( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_wk, + stride_wn, + adaW, # n_exp x ada_block x ada_block + ada_layout, + stride_layout_e, + stride_layout_m, + stride_layout_n, + Y_ptr, + stride_ym, + stride_yn, + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT, + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + x_grouped: tl.constexpr, + y_grouped: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ADA_BLOCK: tl.constexpr, + ADA_BLCKS_PER_TILE_K: tl.constexpr, # how many ada blocks in one tile in K direction + ADA_BLCKS_PER_TILE_N: tl.constexpr, # how many ada blocks in one tile in N direction +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv( + N, BLOCK_N + ) # is 2? numbe of blocks per expert's output dimension + M_block_id = pid // N_BLOCK_COUNT # which expert are we in? (actually block, since there might be multiple blocks per expert) + N_block_id = pid % N_BLOCK_COUNT # which block in the out. dim are we in? + # Determine the block indices along the M and N dimensions for this program. + + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load( + padded_block_idxs + M_block_id + ) # Load the index of the starting token for this block + # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) # max tokens + E_idxs = tl.load( + sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E + ) # expert_idxs_ptr is sorted by expert! 
so this loads expert indices of tokens + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(sorted_scattered_idxs + M_block, mask=E_mask, other=0) + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + + L_BLOCK_K = tl.arange(0, ADA_BLCKS_PER_TILE_K) + L_BLOCK_N = tl.arange(0, ADA_BLCKS_PER_TILE_N) + additive_idx_blocks = (tl.arange(0, ADA_BLOCK))[:, None] * ADA_BLOCK + (tl.arange(0, ADA_BLOCK))[None, :] + L_blck_ptrs = ( + ada_layout + + L_BLOCK_K[:, None] * stride_layout_m + + L_BLOCK_N[None, :] * stride_layout_n + + N_block_id * ADA_BLCKS_PER_TILE_N + + E_idx * stride_layout_e + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(K, BLOCK_K) + for K_block_id in range(0, iters): + if NO_K_MASK: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + if NO_N_MASK or K_block_id < (iters - 1): + w = tl.load(W_blk_ptrs) + else: + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + layout_tile = tl.load(L_blck_ptrs) # 2 x 8 + # BETTER TO RESAHPE MEMORY ADDRESSES, NOT THE LOADED DATA? + mask = layout_tile >= 0 + base_addresses = adaW + (layout_tile * (ADA_BLOCK * ADA_BLOCK)) + full_addresses = base_addresses[:,None,:,None] + additive_idx_blocks[None,:,None,:] + full_addresses = full_addresses.reshape(ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK) + mask = mask[:, None, :, None] * (tl.zeros((1, ADA_BLOCK, 1, ADA_BLOCK), dtype=ACC_TYPE) + 1.0) + mask = mask.reshape(ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK) > 0.0 + + adaW_tile = tl.load( + full_addresses, + mask=mask, + other=0.0, + ) + w = w + adaW_tile #.reshape(ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK) + L_blck_ptrs += ADA_BLCKS_PER_TILE_K * stride_layout_m + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc += tl.dot(x, w, out_dtype=ACC_TYPE) + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + +def scatter2scatter( + X, + W, + ada_weights, + ada_block, + ada_layout, + sorted_expert_idxs, + sorted_scattered_idxs, + k, + padded_block_idxs, + x_grouped=False, + y_grouped=False, + out=None, +): + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + # Pre-kernel setup + x_dim = X.size(-1) + y_dim = W.size(-1) + L_scattered = sorted_expert_idxs.size(0) + if out is None: + O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) + # O = torch.empty_like(ada_weights, device=X.device, dtype=ada_weights.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == y_dim + O = out + + def grid(META): + grid_num = ( + padded_block_idxs.size(0) * triton.cdiv(META["N"], META["BLOCK_N"]), + ) + return grid_num + + assert _scatter2scatter_configs()[0].kwargs["BLOCK_N"] % ada_block == 0 + assert _scatter2scatter_configs()[0].kwargs["BLOCK_K"] % ada_block == 0 + assert 
(ada_layout.size(1) * ada_block) % W.size(-1) == 0 + + M, K = X.size() + N = y_dim + E = (ada_layout.size(1) * ada_block) // W.size(-1) + ada_layout_stride_e = N // ada_block + # sorted_expert_idxs = sorted_expert_idxs.to(torch.int32) + # sorted_scattered_idxs = sorted_scattered_idxs.to(torch.int32) + # padded_block_idxs = padded_block_idxs.to(torch.int32) + + # with torch.cuda.device(X.device): + _scatter2scatter[grid]( + X, + X.stride(0), + X.stride(1), + W, + W.stride(0), + W.stride(1), + ada_weights, # n_exp x ada_block x ada_block + ada_layout, + ada_layout_stride_e, + ada_layout.stride(0), + ada_layout.stride(1), + O, + O.stride(0), + O.stride(1), + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT=k, + M=M, + K=K, + N=N, + E=E, + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=True, + x_grouped=x_grouped, + y_grouped=y_grouped, + ADA_BLOCK=ada_block, + ) + return O + +def _scatter2scatter_sp_configs(): + return [ + # triton.Config({"BLOCK_K": 128}, num_stages=4, num_warps=4), + ] + + +@triton.autotune( + configs=_scatter2scatter_sp_configs(), + key=["M", "N"], +) +@triton.jit +def _scatter2scatter_sp( + X_ptr, + stride_xm, + stride_xk, + gates, + adaW, # n_exp x ada_block x ada_block + adaW_stride_e, + adaW_stride_m, + adaW_stride_n, + base_act, + column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) + column_indices_t_offset, + offsets_t, # offsets for columns: i.e. the diff between two consecutive gives the number of blocks per column. It indexes column_indices_t, block_offsets_t + offsets_t_offset, + block_offsets_t, # indices of blocks sorted by column + block_offsets_t_offset, + Y_ptr, + stride_ym, + stride_yn, + # OW, + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT, + M, + N, + E, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_M: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv( + N, BLOCK_N + ) # is 2? numbe of blocks per expert's output dimension + M_block_id = pid // N_BLOCK_COUNT # which expert are we in? (actually block, since there might be multiple blocks per expert) + N_block_id = pid % N_BLOCK_COUNT # which block in the out. dim are we in? + # Determine the block indices along the M and N dimensions for this program. + + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load(padded_block_idxs + M_block_id) # Load the index of the starting token for this block + # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) # max tokens + E_idxs = tl.load(sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E) # expert_idxs_ptr is sorted by expert! 
so this loads expert indices of tokens + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(sorted_scattered_idxs + M_block, mask=E_mask, other=0) + M_in_idx = M_idx // FAN_OUT + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + start_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id) + end_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id + 1) + num_blocks_column = end_inx - start_inx + iters = num_blocks_column #tl.cdiv(num_blocks_column, tl.cdiv(BLOCK_K, ADA_BLOCK)) # n_blocks_column + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + gate = tl.load(gates + M_idx, mask=E_mask) + + if iters > 0: + # pointers to dense matrix + X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) #...16 + rbk = tl.arange(0, BLOCK_K) # ... 16 + W_blk_ptr = adaW + (rbk[:, None] * adaW_stride_m) + (rn[None, :] * adaW_stride_n) + (E_idx * adaW_stride_e) + BLOCK_SIZE = BLOCK_K * BLOCK_N + ak_block_incr = stride_xk * BLOCK_K + + # OW_block_ptr = OW + (rbk[:, None] * adaW_stride_m) + (rn[None, :] * adaW_stride_n) + (E_idx * adaW_stride_e) + + for K_block_id in range(0, iters): + X = X_blk_ptr + tl.load(column_indices_t + (E_idx * column_indices_t_offset) + start_inx + K_block_id) * ak_block_incr + + W = W_blk_ptr + tl.load(block_offsets_t + (E_idx * block_offsets_t_offset) + start_inx + K_block_id) * BLOCK_SIZE + # OWW = OW_block_ptr + tl.load(block_offsets_t + (E_idx * block_offsets_t_offset) + start_inx + K_block_id) * BLOCK_SIZE + + x = tl.load(X, mask=E_mask[:, None]) + w = tl.load(W, mask=N_mask[None, :]) + acc += tl.dot(x, w, out_dtype=ACC_TYPE) + + # tl.store(OWW, w) + + base_act_ptr = base_act + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + base_act = tl.load(base_act_ptr, mask=E_mask[:, None] & N_mask[None, :]) + acc *= gate[:, None] + acc += base_act + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + + + +def scatter2scatter_sparse( + X, + base_act, + ada_weights, + row_idxs, + col_idxs_t, + ada_block, + offsets_t, + block_offsets_t, + sorted_expert_idxs, + sorted_scattered_idxs, + k, + padded_block_idxs, + gates, + out=None, +): + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + assert X.is_contiguous() + assert base_act.is_contiguous() + assert ada_weights.is_contiguous() + assert row_idxs.is_contiguous() + assert col_idxs_t.is_contiguous() + assert offsets_t.is_contiguous() + assert block_offsets_t.is_contiguous() + assert sorted_expert_idxs.is_contiguous() + assert sorted_scattered_idxs.is_contiguous() + assert padded_block_idxs.is_contiguous() + assert gates.is_contiguous() + + + # Pre-kernel setup + x_dim = X.size(-1) + y_dim = base_act.size(-1) + L_scattered = sorted_expert_idxs.size(0) + if out is None: + O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) + # O = torch.empty_like(ada_weights, device=X.device, dtype=ada_weights.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == y_dim + O = out + # OW = torch.empty_like(ada_weights, device=X.device, dtype=ada_weights.dtype) + def grid(META): + grid_num = ( + padded_block_idxs.size(0) * triton.cdiv(META["N"], META["BLOCK_N"]), + ) + return grid_num + + M, K = X.size() + N = y_dim + E = 
ada_weights.size(0) + with torch.cuda.device(X.device): + _scatter2scatter_sp[grid]( + X, + X.stride(0), + X.stride(1), + gates, + ada_weights, # n_exp x ada_block x ada_block + ada_weights.stride(0), + ada_weights.stride(2), + ada_weights.stride(3), + base_act, + col_idxs_t, + col_idxs_t.stride(0), + offsets_t, # column offsets shapre is (E, N//ada_block + 1) + offsets_t.stride(0), + block_offsets_t, + block_offsets_t.stride(0), + O, + O.stride(0), + O.stride(1), + # OW, + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT=k, + M=M, + N=N, + E=E, + BLOCK_M=BLOCK_M, + BLOCK_K=ada_block, + BLOCK_N=ada_block, + ACC_TYPE=tl.float32 + ) + return O + + +@triton.autotune( + configs=[ + triton.Config({"GROUP_M": 1, "BLOCK_M": 128}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 4, "BLOCK_M": 128}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 32, "BLOCK_M": 128}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 128, "BLOCK_M": 128}, num_stages=4, num_warps=4), + + triton.Config({"GROUP_M": 1, "BLOCK_M": 64}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 4, "BLOCK_M": 64}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 32, "BLOCK_M": 64}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 128, "BLOCK_M": 64}, num_stages=4, num_warps=4), + + + triton.Config({"GROUP_M": 1, "BLOCK_M": 256}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 4, "BLOCK_M": 256}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 32, "BLOCK_M": 256}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 128, "BLOCK_M": 256}, num_stages=4, num_warps=4), + ], + key=["M", "N", "E"], +) +@triton.jit +def _scatter2scatter_sp_optimized( + X_ptr, + stride_xm, + stride_xk, + gates, + adaW, # n_exp x ada_block x ada_block + adaW_stride_e, + adaW_stride_m, + adaW_stride_n, + base_act, + column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) + column_indices_t_offset, + offsets_t, # offsets for columns: i.e. the diff between two consecutive gives the number of blocks per column. It indexes column_indices_t, block_offsets_t + offsets_t_offset, + block_offsets_t, # indices of blocks sorted by column + block_offsets_t_offset, + Y_ptr, + stride_ym, + stride_yn, + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT: tl.constexpr, + M: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_M: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + M_block_id = pid_m # which expert are we in? (actually block, since there might be multiple blocks per expert) + N_block_id =pid_n # which block in the out. dim are we in? + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load(padded_block_idxs + M_block_id) # Load the index of the starting token for this block + # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) # max tokens + E_idxs = tl.load(sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E) # expert_idxs_ptr is sorted by expert! 
so this loads expert indices of tokens + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(sorted_scattered_idxs + M_block, mask=E_mask, other=0) + M_in_idx = M_idx // FAN_OUT + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + start_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id) + end_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id + 1) + num_blocks_column = end_inx - start_inx + iters = num_blocks_column #tl.cdiv(num_blocks_column, tl.cdiv(BLOCK_K, ADA_BLOCK)) # n_blocks_column + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + gate = tl.load(gates + M_idx, mask=E_mask) + + if iters > 0: + # pointers to dense matrix + X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) #...16 + rbk = tl.arange(0, BLOCK_K) # ... 16 + W_blk_ptr = adaW + (rbk[:, None] * adaW_stride_m) + (rn[None, :] * adaW_stride_n) + (E_idx * adaW_stride_e) + BLOCK_SIZE = BLOCK_K * BLOCK_N + ak_block_incr = stride_xk * BLOCK_K + + for K_block_id in range(0, iters): + X = X_blk_ptr + tl.load(column_indices_t + (E_idx * column_indices_t_offset) + start_inx + K_block_id) * ak_block_incr + + W = W_blk_ptr + tl.load(block_offsets_t + (E_idx * block_offsets_t_offset) + start_inx + K_block_id) * BLOCK_SIZE + + x = tl.load(X, mask=E_mask[:, None]) + w = tl.load(W, mask=N_mask[None, :]) + acc += tl.dot(x, w, out_dtype=ACC_TYPE) + + base_act_ptr = base_act + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + base_act = tl.load(base_act_ptr, mask=E_mask[:, None] & N_mask[None, :]) + acc *= gate[:, None] + acc += base_act + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + # tl.atomic_add(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :], scope="cta") + + +def scatter2scatter_sparse_optimized( + X, + base_act, + ada_weights, + row_idxs, + col_idxs_t, + ada_block, + offsets_t, + block_offsets_t, + sorted_expert_idxs, + sorted_scattered_idxs, + k, + padded_block_idxs, + gates, + out=None, + ): + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + assert X.is_contiguous() + assert base_act.is_contiguous() + assert ada_weights.is_contiguous() + assert row_idxs.is_contiguous() + assert col_idxs_t.is_contiguous() + assert offsets_t.is_contiguous() + assert block_offsets_t.is_contiguous() + assert sorted_expert_idxs.is_contiguous() + assert sorted_scattered_idxs.is_contiguous() + assert padded_block_idxs.is_contiguous() + assert gates.is_contiguous() + + + # Pre-kernel setup + x_dim = X.size(-1) + y_dim = base_act.size(-1) + L_scattered = sorted_expert_idxs.size(0) + if out is None: + O = torch.zeros((L_scattered, y_dim), device=X.device, dtype=X.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == y_dim + O = out + + def grid(META): + grid_num = ( + padded_block_idxs.size(0), triton.cdiv(META["N"], META["BLOCK_N"]), + ) + return grid_num + + M, K = X.size() + N = y_dim + E = ada_weights.size(0) + with torch.cuda.device(X.device): + _scatter2scatter_sp_optimized[grid]( + X, + X.stride(0), + X.stride(1), + gates, + ada_weights, # n_exp x ada_block x ada_block + ada_weights.stride(0), + ada_weights.stride(2), + ada_weights.stride(3), + base_act, + col_idxs_t, + col_idxs_t.stride(0), + offsets_t, # 
column offsets shapre is (E, N//ada_block + 1) + offsets_t.stride(0), + block_offsets_t, + block_offsets_t.stride(0), + O, + O.stride(0), + O.stride(1), + # OW, + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT=k, + M=M, + N=N, + E=E, + BLOCK_K=ada_block, + BLOCK_N=ada_block, + ACC_TYPE=tl.float32, + ) + return O \ No newline at end of file diff --git a/mttl/models/modifiers/spasity/stk/triton_kernels.py b/mttl/models/modifiers/spasity/stk/triton_kernels.py index 871d8fa58..91249c8d8 100644 --- a/mttl/models/modifiers/spasity/stk/triton_kernels.py +++ b/mttl/models/modifiers/spasity/stk/triton_kernels.py @@ -195,3 +195,542 @@ def sdd_spmerge( ada_layout.stride(1), ACC_TYPE=ACC_TYPE, ) + + +# this is from https://github.com/databricks/megablocks/blob/7b0337fa7278d224bf0c9be71c3a92c392fdd340/megablocks/backend/kernels.py#L107 + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. 
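+    # With A_TO_B=True we read the token-major 'a' and write the expert-grouped, padded 'b';
+    # with A_TO_B=False we read 'b' and write one row per (token, top_k) slot back into 'a',
+    # and the caller sums over the top_k copies afterwards (see padded_scatter).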
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. 
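+# Each program reduces one row: wgrad[i] = dot(x[padded_row(i)], grad[i // top_k]), i.e. the
+# gradient of the loss with respect to the routing weight that scaled this expert's output row.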
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. 
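+    # 'b' is laid out as (num_experts, expert_capacity, num_columns); each program instance
+    # copies the row belonging to one (expert, capacity-slot) pair.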
+ index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. 
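+# Note: unlike _binned_copy above, _binned_copy_wgrad takes x: (num_experts, expert_capacity,
+# num_columns) and grad: (tokens, num_columns) and writes wgrad: (tokens * top_k,); it is the
+# capacity-binned analogue of _padded_copy_wgrad.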
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. 
+ assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out \ No newline at end of file From f95fd5dc36a0476393ea8d04879b8686e364189e Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 3 Oct 2024 15:50:04 -0400 Subject: [PATCH 07/24] nvm --- mttl/models/modifiers/spasity/stk/measure_time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mttl/models/modifiers/spasity/stk/measure_time.py b/mttl/models/modifiers/spasity/stk/measure_time.py index 4b2ead74a..280879426 100644 --- a/mttl/models/modifiers/spasity/stk/measure_time.py +++ b/mttl/models/modifiers/spasity/stk/measure_time.py @@ -177,4 +177,4 @@ def lora_merge(lora_a, lora_b, x, W_base, W_merge): lora_a = torch.randn(E, d, lora_rank, dtype=dtype).cuda().contiguous() lora_b = torch.randn(E, lora_rank, h, dtype=dtype).cuda().contiguous() func_lora = partial(lora_merge, lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights) - # benchmark_module("LoRA merge (our current vanila)", func_lora) \ No newline at end of file + benchmark_module("LoRA merge (our current vanila)", func_lora) \ No newline at end of file From 85859c995429f494ff53ef1a14dbb776240349f7 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 3 Oct 2024 15:58:51 -0400 Subject: [PATCH 08/24] nvm --- .../modifiers/spasity/stk/linear_ops_test_scatter.py | 4 ++-- mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py b/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py index f1cf80bb7..9f5a3672d 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py +++ b/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py @@ -45,8 +45,8 @@ def _dense(rows, cols, dtype, std=0.1): SC_MOE_TEST = { - # (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), - # (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), + (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), + (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), (8, 128, 256, 10, 2, 0.8, 16, torch.float32), } diff --git a/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py b/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py index 67dbf2adc..af4d10076 100644 --- a/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py +++ b/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py @@ -247,7 +247,7 @@ def _scatter2scatter_sp( adaW_stride_e, adaW_stride_m, adaW_stride_n, - base_act, + base_acts, column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) column_indices_t_offset, offsets_t, # offsets for columns: i.e. the diff between two consecutive gives the number of blocks per column. 
It indexes column_indices_t, block_offsets_t @@ -326,7 +326,7 @@ def _scatter2scatter_sp( # tl.store(OWW, w) - base_act_ptr = base_act + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + base_act_ptr = base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn base_act = tl.load(base_act_ptr, mask=E_mask[:, None] & N_mask[None, :]) acc *= gate[:, None] acc += base_act @@ -454,7 +454,7 @@ def _scatter2scatter_sp_optimized( adaW_stride_e, adaW_stride_m, adaW_stride_n, - base_act, + base_acts, column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) column_indices_t_offset, offsets_t, # offsets for columns: i.e. the diff between two consecutive gives the number of blocks per column. It indexes column_indices_t, block_offsets_t @@ -527,7 +527,7 @@ def _scatter2scatter_sp_optimized( w = tl.load(W, mask=N_mask[None, :]) acc += tl.dot(x, w, out_dtype=ACC_TYPE) - base_act_ptr = base_act + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + base_act_ptr = base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn base_act = tl.load(base_act_ptr, mask=E_mask[:, None] & N_mask[None, :]) acc *= gate[:, None] acc += base_act From c8d53617d085b0d194c22df116ced5f77fd55e18 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 3 Oct 2024 16:00:41 -0400 Subject: [PATCH 09/24] formating --- .../models/modifiers/spasity/stk/functions.py | 5 +- .../modifiers/spasity/stk/linear_ops.py | 1 + .../spasity/stk/linear_ops_test_megatron.py | 8 +- .../spasity/stk/linear_ops_test_scatter.py | 10 +- .../modifiers/spasity/stk/matrix_ops.py | 26 +- .../modifiers/spasity/stk/matrix_ops_test.py | 18 +- .../modifiers/spasity/stk/measure_time.py | 17 +- .../spasity/stk/scatter_moe_kernels.py | 243 ++++++++++++------ .../modifiers/spasity/stk/triton_kernels.py | 71 ++--- 9 files changed, 239 insertions(+), 160 deletions(-) diff --git a/mttl/models/modifiers/spasity/stk/functions.py b/mttl/models/modifiers/spasity/stk/functions.py index 1af296089..6e21fd435 100644 --- a/mttl/models/modifiers/spasity/stk/functions.py +++ b/mttl/models/modifiers/spasity/stk/functions.py @@ -124,7 +124,7 @@ def forward( sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, - gates + gates, ): output = scatter2scatter_sparse( @@ -151,7 +151,6 @@ def forward( parallel_linear = ParalleLinear.apply - class ParalleLinear2(torch.autograd.Function): @staticmethod @@ -194,4 +193,4 @@ def forward( return output -parallel_linear_optimized = ParalleLinear2.apply \ No newline at end of file +parallel_linear_optimized = ParalleLinear2.apply diff --git a/mttl/models/modifiers/spasity/stk/linear_ops.py b/mttl/models/modifiers/spasity/stk/linear_ops.py index 5a6bbbd69..e68ed333b 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops.py +++ b/mttl/models/modifiers/spasity/stk/linear_ops.py @@ -78,6 +78,7 @@ def scattergather_adamerge( ) return out + def scattergather_adamerge2( x, base_act, diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py b/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py index 15ba7e88b..3679cff12 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py +++ b/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py @@ -157,6 +157,7 @@ def testLinearOps_Sdd_wAdapters( (8, 128, 256, 10, 2, 0.8, 16, torch.float32), } + def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): output = torch.stack( [ @@ -219,8 +220,7 @@ def testScatteredMoE(self, bs, d, h, E, k, 
sparsity, blocking, dtype): padded_block_idxs=padded_block_idxs, gates=k_weights, ) - - + out2 = functions.parallel_linear_optimized( x=X, base_act=base_act, @@ -236,9 +236,7 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): padded_block_idxs=padded_block_idxs, gates=k_weights, ) - - - + out_dumb = dumb_forward(base_act, X, k_weights, expert_idxs, adaps_dense) err_Y = torch.abs(out - out_dumb) tolerance = 1e-2 diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py b/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py index 9f5a3672d..97246075f 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py +++ b/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py @@ -14,6 +14,7 @@ # os.environ["TRITON_INTERPRET"] = "1" + def allclose(x, y, pct=0.25): mask = torch.isclose(x, y, rtol=5e-2) pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 @@ -43,13 +44,13 @@ def _dense(rows, cols, dtype, std=0.1): return out.to(cuda_device).requires_grad_(True) - SC_MOE_TEST = { (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), (8, 128, 256, 10, 2, 0.8, 16, torch.float32), } + def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): output = torch.stack( [ @@ -112,8 +113,7 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): # padded_block_idxs=padded_block_idxs, # gates=k_weights, # ) - - + out2 = linear_ops.scattergather_adamerge2( x=X, base_act=base_act, @@ -129,9 +129,7 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): padded_block_idxs=padded_block_idxs, gates=k_weights, ) - - - + out_dumb = dumb_forward(base_act, X, k_weights, expert_idxs, adaps_dense) err_Y = torch.abs(out2 - out_dumb) tolerance = 1e-2 diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops.py b/mttl/models/modifiers/spasity/stk/matrix_ops.py index 0163550a6..6f22cc3d6 100644 --- a/mttl/models/modifiers/spasity/stk/matrix_ops.py +++ b/mttl/models/modifiers/spasity/stk/matrix_ops.py @@ -6,11 +6,11 @@ from stk.matrix import Matrix -def _merge_adapters (adapters: List[Matrix]) -> Matrix: +def _merge_adapters(adapters: List[Matrix]) -> Matrix: """ Merges a list of adapters into a single adapter along the second dimention. Also changes the block size by padding blocks iwht 0s if necessary. - + """ col_indices_list = [adap.column_indices.to(torch.int32) for adap in adapters] # row_indices_list = [adap.row_indices for adap in adapters] @@ -23,7 +23,7 @@ def _merge_adapters (adapters: List[Matrix]) -> Matrix: ), "All adapters must have the same number of rows" block_size = adapters[0].blocking - + K, N = adapters[0].size() col_offset = N // block_size # assuming all have same number of cols n_adaps = len(adapters) @@ -61,29 +61,27 @@ def _merge_adapters (adapters: List[Matrix]) -> Matrix: return Matrix((K, n_adaps * N), data, row_indices, col_indices, offsets) + def change_block_size(M: Matrix, new_blk_size) -> Matrix: raise NotImplementedError("change_block_size is not implemented yet") - return - - - - + return -def merge_adapters(adapters: List[Matrix], blk_size = None) -> Matrix: +def merge_adapters(adapters: List[Matrix], blk_size=None) -> Matrix: """ Merges a list of adapters into a single adapter along the second dimention. Also changes the block size by padding blocks iwht 0s if necessary. 
- + """ - - out = _merge_adapters(adapters) # merges the adapters into a single Matrix() without changing the block size + + out = _merge_adapters( + adapters + ) # merges the adapters into a single Matrix() without changing the block size if blk_size is not None: - out = change_block_size(out, blk_size) + out = change_block_size(out, blk_size) return out - def create_ada_layout(matix: Matrix): """ Creates a binary tensor that identifies if block exists in the adapter matrix diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops_test.py b/mttl/models/modifiers/spasity/stk/matrix_ops_test.py index fd4da23db..01475df4a 100644 --- a/mttl/models/modifiers/spasity/stk/matrix_ops_test.py +++ b/mttl/models/modifiers/spasity/stk/matrix_ops_test.py @@ -19,21 +19,21 @@ def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): sparse.to(cuda_device).requires_grad_(True), ) -@parameterized.parameters( - (2, 8, 16, 0.5, 1), - (2, 8, 16, 0.5, 4) - ) -class MatrixOpsTest(parameterized.TestCase): + +@parameterized.parameters((2, 8, 16, 0.5, 1), (2, 8, 16, 0.5, 4)) +class MatrixOpsTest(parameterized.TestCase): def test_layout_creation(self, K, rows, cols, sparsity, blocking): - adaps = [_dense_and_sparse(rows, cols, sparsity, blocking, torch.float16) for _ in range(K)] + adaps = [ + _dense_and_sparse(rows, cols, sparsity, blocking, torch.float16) + for _ in range(K) + ] adaps_sparse = [adap[1] for adap in adaps] # adaps_dense = [adap[0] for adap in adaps] - + merged_adaps_matrix: Matrix = matrix_ops.merge_adapters(adaps_sparse) layout = matrix_ops.create_ada_layout(merged_adaps_matrix) assert layout.max() == merged_adaps_matrix.data.size(0) - 1 - -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/mttl/models/modifiers/spasity/stk/measure_time.py b/mttl/models/modifiers/spasity/stk/measure_time.py index 280879426..c769f68c8 100644 --- a/mttl/models/modifiers/spasity/stk/measure_time.py +++ b/mttl/models/modifiers/spasity/stk/measure_time.py @@ -58,6 +58,7 @@ def benchmark_module(name, function, runs=100): def calculate_lora_parameters(input_dim, output_dim, rank): return input_dim * rank + output_dim * rank + def find_lora_hyperpaams(d_in, d_out, tot_sparse_params): lora_ranks = [] lora_rank = 1 @@ -70,8 +71,9 @@ def find_lora_hyperpaams(d_in, d_out, tot_sparse_params): lora_ranks.append(lora_rank) return int(np.mean(lora_ranks)) + SC_MOE_TEST = { - # bs, d, h, E, k, sparsity, blocking, dtype + # bs, d, h, E, k, sparsity, blocking, dtype (1024, 2048, 8192, 20, 2, 0.995, 16, torch.float16), (1024, 2048, 8192, 20, 2, 0.9, 128, torch.float16), (1024, 2048, 8192, 100, 2, 0.995, 16, torch.float16), @@ -161,20 +163,21 @@ def call_with_baseact_and_idxs_computation(X, W, expert_idxs, function, **kwargs ) benchmark_module("BS kernel optimized", func) lora_rank = find_lora_hyperpaams(d, h, np.prod(ada_data.shape[1:])) - - + def lora_merge(lora_a, lora_b, x, W_base, W_merge): # LoRA does not profit from lower top-k in this vanila form # merge into 1 lora A = torch.einsum("be,edr->bdr", (W_merge, lora_a)) B = torch.einsum("be,erd->brd", (W_merge, lora_b)) # lora forward - partial_out = torch.einsum("bd,bdr->br", (x, A)) + partial_out = torch.einsum("bd,bdr->br", (x, A)) adapter_out = torch.einsum("br,brd->bd", (partial_out, B)) dense_out = x @ W_base return adapter_out + dense_out - + lora_a = torch.randn(E, d, lora_rank, dtype=dtype).cuda().contiguous() lora_b = torch.randn(E, lora_rank, h, dtype=dtype).cuda().contiguous() - func_lora = partial(lora_merge, 
lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights) - benchmark_module("LoRA merge (our current vanila)", func_lora) \ No newline at end of file + func_lora = partial( + lora_merge, lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights + ) + benchmark_module("LoRA merge (our current vanila)", func_lora) diff --git a/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py b/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py index af4d10076..ff197a659 100644 --- a/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py +++ b/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py @@ -58,15 +58,17 @@ def _scatter2scatter( NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr, ADA_BLOCK: tl.constexpr, - ADA_BLCKS_PER_TILE_K: tl.constexpr, # how many ada blocks in one tile in K direction - ADA_BLCKS_PER_TILE_N: tl.constexpr, # how many ada blocks in one tile in N direction + ADA_BLCKS_PER_TILE_K: tl.constexpr, # how many ada blocks in one tile in K direction + ADA_BLCKS_PER_TILE_N: tl.constexpr, # how many ada blocks in one tile in N direction ): pid = tl.program_id(axis=0) N_BLOCK_COUNT = tl.cdiv( N, BLOCK_N ) # is 2? numbe of blocks per expert's output dimension - M_block_id = pid // N_BLOCK_COUNT # which expert are we in? (actually block, since there might be multiple blocks per expert) + M_block_id = ( + pid // N_BLOCK_COUNT + ) # which expert are we in? (actually block, since there might be multiple blocks per expert) N_block_id = pid % N_BLOCK_COUNT # which block in the out. dim are we in? # Determine the block indices along the M and N dimensions for this program. @@ -99,10 +101,12 @@ def _scatter2scatter( X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn - + L_BLOCK_K = tl.arange(0, ADA_BLCKS_PER_TILE_K) L_BLOCK_N = tl.arange(0, ADA_BLCKS_PER_TILE_N) - additive_idx_blocks = (tl.arange(0, ADA_BLOCK))[:, None] * ADA_BLOCK + (tl.arange(0, ADA_BLOCK))[None, :] + additive_idx_blocks = (tl.arange(0, ADA_BLOCK))[:, None] * ADA_BLOCK + ( + tl.arange(0, ADA_BLOCK) + )[None, :] L_blck_ptrs = ( ada_layout + L_BLOCK_K[:, None] * stride_layout_m @@ -110,7 +114,7 @@ def _scatter2scatter( + N_block_id * ADA_BLCKS_PER_TILE_N + E_idx * stride_layout_e ) - + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) iters = tl.cdiv(K, BLOCK_K) for K_block_id in range(0, iters): @@ -129,17 +133,30 @@ def _scatter2scatter( # BETTER TO RESAHPE MEMORY ADDRESSES, NOT THE LOADED DATA? 
mask = layout_tile >= 0 base_addresses = adaW + (layout_tile * (ADA_BLOCK * ADA_BLOCK)) - full_addresses = base_addresses[:,None,:,None] + additive_idx_blocks[None,:,None,:] - full_addresses = full_addresses.reshape(ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK) - mask = mask[:, None, :, None] * (tl.zeros((1, ADA_BLOCK, 1, ADA_BLOCK), dtype=ACC_TYPE) + 1.0) - mask = mask.reshape(ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK) > 0.0 - + full_addresses = ( + base_addresses[:, None, :, None] + additive_idx_blocks[None, :, None, :] + ) + full_addresses = full_addresses.reshape( + ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK + ) + mask = mask[:, None, :, None] * ( + tl.zeros((1, ADA_BLOCK, 1, ADA_BLOCK), dtype=ACC_TYPE) + 1.0 + ) + mask = ( + mask.reshape( + ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK + ) + > 0.0 + ) + adaW_tile = tl.load( full_addresses, mask=mask, other=0.0, - ) - w = w + adaW_tile #.reshape(ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK) + ) + w = ( + w + adaW_tile + ) # .reshape(ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK) L_blck_ptrs += ADA_BLCKS_PER_TILE_K * stride_layout_m X_blk_ptrs += BLOCK_K * stride_xk W_blk_ptrs += BLOCK_K * stride_wk @@ -175,7 +192,7 @@ def scatter2scatter( else: assert out.size(0) == L_scattered and out.size(1) == y_dim O = out - + def grid(META): grid_num = ( padded_block_idxs.size(0) * triton.cdiv(META["N"], META["BLOCK_N"]), @@ -193,8 +210,8 @@ def grid(META): # sorted_expert_idxs = sorted_expert_idxs.to(torch.int32) # sorted_scattered_idxs = sorted_scattered_idxs.to(torch.int32) # padded_block_idxs = padded_block_idxs.to(torch.int32) - - # with torch.cuda.device(X.device): + + # with torch.cuda.device(X.device): _scatter2scatter[grid]( X, X.stride(0), @@ -227,6 +244,7 @@ def grid(META): ) return O + def _scatter2scatter_sp_configs(): return [ # triton.Config({"BLOCK_K": 128}, num_stages=4, num_warps=4), @@ -248,11 +266,11 @@ def _scatter2scatter_sp( adaW_stride_m, adaW_stride_n, base_acts, - column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) + column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) column_indices_t_offset, - offsets_t, # offsets for columns: i.e. the diff between two consecutive gives the number of blocks per column. It indexes column_indices_t, block_offsets_t + offsets_t, # offsets for columns: i.e. the diff between two consecutive gives the number of blocks per column. It indexes column_indices_t, block_offsets_t offsets_t_offset, - block_offsets_t, # indices of blocks sorted by column + block_offsets_t, # indices of blocks sorted by column block_offsets_t_offset, Y_ptr, stride_ym, @@ -275,15 +293,21 @@ def _scatter2scatter_sp( N_BLOCK_COUNT = tl.cdiv( N, BLOCK_N ) # is 2? numbe of blocks per expert's output dimension - M_block_id = pid // N_BLOCK_COUNT # which expert are we in? (actually block, since there might be multiple blocks per expert) + M_block_id = ( + pid // N_BLOCK_COUNT + ) # which expert are we in? (actually block, since there might be multiple blocks per expert) N_block_id = pid % N_BLOCK_COUNT # which block in the out. dim are we in? # Determine the block indices along the M and N dimensions for this program. 
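# [Editor's sketch, not part of the patch] column_indices_t / offsets_t / block_offsets_t form a
# per-expert block-CSC view of the sparse adapter: for output-block column j of expert e, the
# stored blocks occupy positions offsets_t[e, j] : offsets_t[e, j + 1]; column_indices_t gives
# each block's K-direction block index and block_offsets_t its slot in the flat block storage.
# An illustrative Python loop over one column (names assumed, not the kernel's API):
def iter_column_blocks(offsets_t, column_indices_t, block_offsets_t, e, j):
    start, end = int(offsets_t[e, j]), int(offsets_t[e, j + 1])
    for p in range(start, end):
        k_block = int(column_indices_t[e, p])    # which K block of the input this block multiplies
        data_block = int(block_offsets_t[e, p])  # where the block's values live in ada_weights[e]
        yield k_block, data_block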
M_range = tl.arange(0, BLOCK_M) - block_start_idx = tl.load(padded_block_idxs + M_block_id) # Load the index of the starting token for this block + block_start_idx = tl.load( + padded_block_idxs + M_block_id + ) # Load the index of the starting token for this block # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) # max tokens - E_idxs = tl.load(sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E) # expert_idxs_ptr is sorted by expert! so this loads expert indices of tokens + E_idxs = tl.load( + sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E + ) # expert_idxs_ptr is sorted by expert! so this loads expert indices of tokens E_idx = tl.min(E_idxs) E_mask = E_idxs == E_idx M_idx = tl.load(sorted_scattered_idxs + M_block, mask=E_mask, other=0) @@ -296,47 +320,70 @@ def _scatter2scatter_sp( start_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id) end_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id + 1) num_blocks_column = end_inx - start_inx - iters = num_blocks_column #tl.cdiv(num_blocks_column, tl.cdiv(BLOCK_K, ADA_BLOCK)) # n_blocks_column + iters = num_blocks_column # tl.cdiv(num_blocks_column, tl.cdiv(BLOCK_K, ADA_BLOCK)) # n_blocks_column acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - + gate = tl.load(gates + M_idx, mask=E_mask) - + if iters > 0: # pointers to dense matrix X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk - + # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) #...16 - rbk = tl.arange(0, BLOCK_K) # ... 16 - W_blk_ptr = adaW + (rbk[:, None] * adaW_stride_m) + (rn[None, :] * adaW_stride_n) + (E_idx * adaW_stride_e) + rn = tl.arange(0, BLOCK_N) # ...16 + rbk = tl.arange(0, BLOCK_K) # ... 
16 + W_blk_ptr = ( + adaW + + (rbk[:, None] * adaW_stride_m) + + (rn[None, :] * adaW_stride_n) + + (E_idx * adaW_stride_e) + ) BLOCK_SIZE = BLOCK_K * BLOCK_N ak_block_incr = stride_xk * BLOCK_K - + # OW_block_ptr = OW + (rbk[:, None] * adaW_stride_m) + (rn[None, :] * adaW_stride_n) + (E_idx * adaW_stride_e) - for K_block_id in range(0, iters): - X = X_blk_ptr + tl.load(column_indices_t + (E_idx * column_indices_t_offset) + start_inx + K_block_id) * ak_block_incr - - W = W_blk_ptr + tl.load(block_offsets_t + (E_idx * block_offsets_t_offset) + start_inx + K_block_id) * BLOCK_SIZE + for K_block_id in range(0, iters): + X = ( + X_blk_ptr + + tl.load( + column_indices_t + + (E_idx * column_indices_t_offset) + + start_inx + + K_block_id + ) + * ak_block_incr + ) + + W = ( + W_blk_ptr + + tl.load( + block_offsets_t + + (E_idx * block_offsets_t_offset) + + start_inx + + K_block_id + ) + * BLOCK_SIZE + ) # OWW = OW_block_ptr + tl.load(block_offsets_t + (E_idx * block_offsets_t_offset) + start_inx + K_block_id) * BLOCK_SIZE - + x = tl.load(X, mask=E_mask[:, None]) - w = tl.load(W, mask=N_mask[None, :]) + w = tl.load(W, mask=N_mask[None, :]) acc += tl.dot(x, w, out_dtype=ACC_TYPE) - + # tl.store(OWW, w) - - base_act_ptr = base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + + base_act_ptr = ( + base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + ) base_act = tl.load(base_act_ptr, mask=E_mask[:, None] & N_mask[None, :]) acc *= gate[:, None] acc += base_act - + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) - - def scatter2scatter_sparse( X, base_act, @@ -366,8 +413,7 @@ def scatter2scatter_sparse( assert sorted_scattered_idxs.is_contiguous() assert padded_block_idxs.is_contiguous() assert gates.is_contiguous() - - + # Pre-kernel setup x_dim = X.size(-1) y_dim = base_act.size(-1) @@ -378,6 +424,7 @@ def scatter2scatter_sparse( else: assert out.size(0) == L_scattered and out.size(1) == y_dim O = out + # OW = torch.empty_like(ada_weights, device=X.device, dtype=ada_weights.dtype) def grid(META): grid_num = ( @@ -388,7 +435,7 @@ def grid(META): M, K = X.size() N = y_dim E = ada_weights.size(0) - with torch.cuda.device(X.device): + with torch.cuda.device(X.device): _scatter2scatter_sp[grid]( X, X.stride(0), @@ -401,7 +448,7 @@ def grid(META): base_act, col_idxs_t, col_idxs_t.stride(0), - offsets_t, # column offsets shapre is (E, N//ada_block + 1) + offsets_t, # column offsets shapre is (E, N//ada_block + 1) offsets_t.stride(0), block_offsets_t, block_offsets_t.stride(0), @@ -419,7 +466,7 @@ def grid(META): BLOCK_M=BLOCK_M, BLOCK_K=ada_block, BLOCK_N=ada_block, - ACC_TYPE=tl.float32 + ACC_TYPE=tl.float32, ) return O @@ -430,13 +477,10 @@ def grid(META): triton.Config({"GROUP_M": 4, "BLOCK_M": 128}, num_stages=4, num_warps=4), triton.Config({"GROUP_M": 32, "BLOCK_M": 128}, num_stages=4, num_warps=4), triton.Config({"GROUP_M": 128, "BLOCK_M": 128}, num_stages=4, num_warps=4), - triton.Config({"GROUP_M": 1, "BLOCK_M": 64}, num_stages=4, num_warps=4), triton.Config({"GROUP_M": 4, "BLOCK_M": 64}, num_stages=4, num_warps=4), triton.Config({"GROUP_M": 32, "BLOCK_M": 64}, num_stages=4, num_warps=4), triton.Config({"GROUP_M": 128, "BLOCK_M": 64}, num_stages=4, num_warps=4), - - triton.Config({"GROUP_M": 1, "BLOCK_M": 256}, num_stages=4, num_warps=4), triton.Config({"GROUP_M": 4, "BLOCK_M": 256}, num_stages=4, num_warps=4), triton.Config({"GROUP_M": 32, "BLOCK_M": 256}, 
num_stages=4, num_warps=4), @@ -455,11 +499,11 @@ def _scatter2scatter_sp_optimized( adaW_stride_m, adaW_stride_n, base_acts, - column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) + column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) column_indices_t_offset, - offsets_t, # offsets for columns: i.e. the diff between two consecutive gives the number of blocks per column. It indexes column_indices_t, block_offsets_t + offsets_t, # offsets for columns: i.e. the diff between two consecutive gives the number of blocks per column. It indexes column_indices_t, block_offsets_t offsets_t_offset, - block_offsets_t, # indices of blocks sorted by column + block_offsets_t, # indices of blocks sorted by column block_offsets_t_offset, Y_ptr, stride_ym, @@ -478,19 +522,23 @@ def _scatter2scatter_sp_optimized( GROUP_M: tl.constexpr, ): pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) + pid_n = tl.program_id(axis=1) num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - M_block_id = pid_m # which expert are we in? (actually block, since there might be multiple blocks per expert) - N_block_id =pid_n # which block in the out. dim are we in? + + M_block_id = pid_m # which expert are we in? (actually block, since there might be multiple blocks per expert) + N_block_id = pid_n # which block in the out. dim are we in? M_range = tl.arange(0, BLOCK_M) - block_start_idx = tl.load(padded_block_idxs + M_block_id) # Load the index of the starting token for this block + block_start_idx = tl.load( + padded_block_idxs + M_block_id + ) # Load the index of the starting token for this block # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) # max tokens - E_idxs = tl.load(sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E) # expert_idxs_ptr is sorted by expert! so this loads expert indices of tokens + E_idxs = tl.load( + sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E + ) # expert_idxs_ptr is sorted by expert! so this loads expert indices of tokens E_idx = tl.min(E_idxs) E_mask = E_idxs == E_idx M_idx = tl.load(sorted_scattered_idxs + M_block, mask=E_mask, other=0) @@ -503,40 +551,65 @@ def _scatter2scatter_sp_optimized( start_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id) end_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id + 1) num_blocks_column = end_inx - start_inx - iters = num_blocks_column #tl.cdiv(num_blocks_column, tl.cdiv(BLOCK_K, ADA_BLOCK)) # n_blocks_column + iters = num_blocks_column # tl.cdiv(num_blocks_column, tl.cdiv(BLOCK_K, ADA_BLOCK)) # n_blocks_column acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - + gate = tl.load(gates + M_idx, mask=E_mask) - + if iters > 0: # pointers to dense matrix - X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) #...16 - rbk = tl.arange(0, BLOCK_K) # ... 16 - W_blk_ptr = adaW + (rbk[:, None] * adaW_stride_m) + (rn[None, :] * adaW_stride_n) + (E_idx * adaW_stride_e) + rn = tl.arange(0, BLOCK_N) # ...16 + rbk = tl.arange(0, BLOCK_K) # ... 
16 + W_blk_ptr = ( + adaW + + (rbk[:, None] * adaW_stride_m) + + (rn[None, :] * adaW_stride_n) + + (E_idx * adaW_stride_e) + ) BLOCK_SIZE = BLOCK_K * BLOCK_N ak_block_incr = stride_xk * BLOCK_K - for K_block_id in range(0, iters): - X = X_blk_ptr + tl.load(column_indices_t + (E_idx * column_indices_t_offset) + start_inx + K_block_id) * ak_block_incr - - W = W_blk_ptr + tl.load(block_offsets_t + (E_idx * block_offsets_t_offset) + start_inx + K_block_id) * BLOCK_SIZE - + for K_block_id in range(0, iters): + X = ( + X_blk_ptr + + tl.load( + column_indices_t + + (E_idx * column_indices_t_offset) + + start_inx + + K_block_id + ) + * ak_block_incr + ) + + W = ( + W_blk_ptr + + tl.load( + block_offsets_t + + (E_idx * block_offsets_t_offset) + + start_inx + + K_block_id + ) + * BLOCK_SIZE + ) + x = tl.load(X, mask=E_mask[:, None]) - w = tl.load(W, mask=N_mask[None, :]) + w = tl.load(W, mask=N_mask[None, :]) acc += tl.dot(x, w, out_dtype=ACC_TYPE) - - base_act_ptr = base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + + base_act_ptr = ( + base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + ) base_act = tl.load(base_act_ptr, mask=E_mask[:, None] & N_mask[None, :]) acc *= gate[:, None] acc += base_act - + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) # tl.atomic_add(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :], scope="cta") - - + + def scatter2scatter_sparse_optimized( X, base_act, @@ -552,7 +625,7 @@ def scatter2scatter_sparse_optimized( padded_block_idxs, gates, out=None, - ): +): assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) assert sorted_scattered_idxs.size(0) == X.size(0) * k assert X.is_contiguous() @@ -566,8 +639,7 @@ def scatter2scatter_sparse_optimized( assert sorted_scattered_idxs.is_contiguous() assert padded_block_idxs.is_contiguous() assert gates.is_contiguous() - - + # Pre-kernel setup x_dim = X.size(-1) y_dim = base_act.size(-1) @@ -577,17 +649,18 @@ def scatter2scatter_sparse_optimized( else: assert out.size(0) == L_scattered and out.size(1) == y_dim O = out - + def grid(META): grid_num = ( - padded_block_idxs.size(0), triton.cdiv(META["N"], META["BLOCK_N"]), + padded_block_idxs.size(0), + triton.cdiv(META["N"], META["BLOCK_N"]), ) return grid_num M, K = X.size() N = y_dim E = ada_weights.size(0) - with torch.cuda.device(X.device): + with torch.cuda.device(X.device): _scatter2scatter_sp_optimized[grid]( X, X.stride(0), @@ -600,7 +673,7 @@ def grid(META): base_act, col_idxs_t, col_idxs_t.stride(0), - offsets_t, # column offsets shapre is (E, N//ada_block + 1) + offsets_t, # column offsets shapre is (E, N//ada_block + 1) offsets_t.stride(0), block_offsets_t, block_offsets_t.stride(0), @@ -619,4 +692,4 @@ def grid(META): BLOCK_N=ada_block, ACC_TYPE=tl.float32, ) - return O \ No newline at end of file + return O diff --git a/mttl/models/modifiers/spasity/stk/triton_kernels.py b/mttl/models/modifiers/spasity/stk/triton_kernels.py index 91249c8d8..7cba8b178 100644 --- a/mttl/models/modifiers/spasity/stk/triton_kernels.py +++ b/mttl/models/modifiers/spasity/stk/triton_kernels.py @@ -36,7 +36,11 @@ def _validate_matmul_dims(M: int, K: int, N: int): num_warps=TritonConfig.NUM_WARPS, ), ], - key=["M", "N", "K"], # uses these keys to decide wether to re-evaluate the choise of best config + key=[ + "M", + "N", + "K", + ], # uses these keys to decide wether to re-evaluate the choise of best config ) @triton.jit # this 
is understood def _sdd_adamerge( @@ -166,7 +170,7 @@ def sdd_spmerge( ACC_TYPE = tl.float32 else: raise ValueError(f"Unsupported dtype: {out.dtype}") - + # launch kernel nnz_blocks = len(row_indices) grid = lambda META: (nnz_blocks,) # this just alunches 61 threadblocks @@ -199,9 +203,10 @@ def sdd_spmerge( # this is from https://github.com/databricks/megablocks/blob/7b0337fa7278d224bf0c9be71c3a92c392fdd340/megablocks/backend/kernels.py#L107 + def assert_is_tensor(x, ndim): if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + raise ValueError(f"Expected {ndim}-tensor but got {x.ndim}-tensor") def assert_is_matrix(x): @@ -210,12 +215,14 @@ def assert_is_matrix(x): def assert_is_vector(x): if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + raise ValueError(f"Expected 1-tensor but got {x.ndim}-tensor") def assert_equal(a, b): if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) + raise ValueError( + f"Expected dimensions to be equal but got {a} and {b}.", + ) # a: (tokens, hidden_size), real. @@ -226,13 +233,13 @@ def assert_equal(a, b): # padded_bins: (num_experts), integer. @triton.autotune( configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), ], - key=['NUM_COLUMNS'], + key=["NUM_COLUMNS"], ) @triton.jit def _padded_copy( @@ -409,13 +416,13 @@ def scatter(x, indices, bin_ids, weights, bins, top_k): # padded_bins: (num_experts), integer. @triton.autotune( configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), ], - key=['NUM_COLUMNS'], + key=["NUM_COLUMNS"], ) @triton.jit def _padded_copy_wgrad( @@ -507,13 +514,13 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): # bins: (num_experts), integer. 
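# [Editor's sketch, not part of the patch] The _binned_copy kernel configured below implements
# binned_gather / binned_scatter: every token routed to expert e is copied into a fixed-capacity
# buffer out[e, 0:num_tokens_e] (tokens beyond the capacity are dropped), optionally scaled by
# its routing weight. A dense PyTorch equivalent of the gather direction, for reference only:
import torch

def binned_gather_reference(x, indices, weights, bins, expert_capacity, top_k):
    num_experts = bins.shape[0]
    out = torch.zeros(num_experts, expert_capacity, x.shape[1], dtype=x.dtype, device=x.device)
    for e in range(num_experts):
        start = int(bins[e - 1]) if e > 0 else 0
        end = int(bins[e])
        for c in range(min(end - start, expert_capacity)):
            idx = int(indices[start + c])
            scale = weights[idx] if weights is not None else 1.0
            out[e, c] = x[idx // top_k] * scale
    return out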
@triton.autotune( configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), ], - key=['NUM_COLUMNS'], + key=["NUM_COLUMNS"], ) @triton.jit def _binned_copy( @@ -593,7 +600,9 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): assert_equal(weights.shape[0], x.shape[0] * top_k) num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + out = torch.zeros( + (num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device + ) _binned_copy[(num_experts, expert_capacity)]( x, @@ -649,13 +658,13 @@ def binned_scatter(x, indices, weights, bins, top_k): # bins: (num_experts), integer. @triton.autotune( configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), ], - key=['NUM_COLUMNS'], + key=["NUM_COLUMNS"], ) @triton.jit def _binned_copy_wgrad( @@ -733,4 +742,4 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k): NUM_COLUMNS=hidden_size, TOP_K=top_k, ) - return out \ No newline at end of file + return out From 0772fd53b160020fac6c403bd9654b6d05cfb8d1 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 09:36:59 -0400 Subject: [PATCH 10/24] refactor --- .../spasity/sparse_utils/bsr_moe_benchmark.py | 6 +- .../stk_matrix_test.py} | 15 +- .../stk_matrix_utils.py} | 17 + .../spasity/{stk => spb_moe}/__init__.py | 0 .../measure_time.py => spb_moe/benchmark.py} | 12 +- .../bsr_adapter_moe_test.py} | 28 +- .../spasity/{stk => spb_moe}/functions.py | 4 +- .../spasity/{stk => spb_moe}/linear_ops.py | 4 +- .../triton_kernels.py} | 0 .../spasity/stk/linear_ops_test_megatron.py | 248 ------ .../modifiers/spasity/stk/triton_kernels.py | 745 ------------------ 11 files changed, 36 insertions(+), 1043 deletions(-) rename mttl/models/modifiers/spasity/{stk/matrix_ops_test.py => sparse_utils/stk_matrix_test.py} (57%) rename mttl/models/modifiers/spasity/{stk/matrix_ops.py => sparse_utils/stk_matrix_utils.py} (84%) rename mttl/models/modifiers/spasity/{stk => spb_moe}/__init__.py (100%) rename mttl/models/modifiers/spasity/{stk/measure_time.py => spb_moe/benchmark.py} (94%) rename mttl/models/modifiers/spasity/{stk/linear_ops_test_scatter.py => spb_moe/bsr_adapter_moe_test.py} (82%) rename mttl/models/modifiers/spasity/{stk => spb_moe}/functions.py (97%) rename mttl/models/modifiers/spasity/{stk => spb_moe}/linear_ops.py (98%) rename mttl/models/modifiers/spasity/{stk/scatter_moe_kernels.py => spb_moe/triton_kernels.py} (100%) delete mode 100644 mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py delete mode 100644 mttl/models/modifiers/spasity/stk/triton_kernels.py diff --git a/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py 
b/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py index da6f79d8d..8bc0417cd 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py +++ b/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py @@ -30,7 +30,7 @@ padded_gather, padded_scatter, ) -from mttl.models.modifiers.spasity.stk import linear_ops, matrix_ops +from mttl.models.modifiers.spasity.spb_moe import _matrix_ops, linear_ops from mttl.models.utils import model_loader_helper, transfer_batch_to_device device = "cuda" @@ -162,7 +162,7 @@ def create_block_diagonal_matrix(bs_m, bs_n, n_blocks): loras = create_adapter_set(adapter_config_lora, layer, K) sparse_modules = create_adapter_set(adapter_config_bs, layer, K) sparse_mtxs = sparsemodules_to_stkmatrix_list(sparse_modules) -adaptersMatrix: Matrix = matrix_ops.merge_adapters(sparse_mtxs).to(device) +adaptersMatrix: Matrix = _matrix_ops.merge_adapters(sparse_mtxs).to(device) W_mege = W_mege.to(dtype=loras[0].lora_a.dtype) top_k_indices = torch.topk(torch.abs(W_mege), top_k, dim=-1).indices @@ -174,7 +174,7 @@ def create_block_diagonal_matrix(bs_m, bs_n, n_blocks): positions_in_expert_padded, padding_mask, ) = padded_gather(x, top_k_indices, K) -layout = matrix_ops.create_ada_layout(adaptersMatrix).to(device) +layout = _matrix_ops.create_ada_layout(adaptersMatrix).to(device) out_blck_size = x.shape[1] x = x.reshape(-1, in_d).contiguous() diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops_test.py b/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_test.py similarity index 57% rename from mttl/models/modifiers/spasity/stk/matrix_ops_test.py rename to mttl/models/modifiers/spasity/sparse_utils/stk_matrix_test.py index 01475df4a..750cc761e 100644 --- a/mttl/models/modifiers/spasity/stk/matrix_ops_test.py +++ b/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_test.py @@ -6,25 +6,14 @@ from absl.testing import parameterized from stk.matrix import Matrix -from mttl.models.modifiers.spasity.stk import matrix_ops - - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return ( - dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True), - ) +from mttl.models.modifiers.spasity.sparse_utils import stk_matrix_utils as matrix_ops @parameterized.parameters((2, 8, 16, 0.5, 1), (2, 8, 16, 0.5, 4)) class MatrixOpsTest(parameterized.TestCase): def test_layout_creation(self, K, rows, cols, sparsity, blocking): adaps = [ - _dense_and_sparse(rows, cols, sparsity, blocking, torch.float16) + matrix_ops._dense_and_sparse(rows, cols, sparsity, blocking, torch.float16) for _ in range(K) ] adaps_sparse = [adap[1] for adap in adaps] diff --git a/mttl/models/modifiers/spasity/stk/matrix_ops.py b/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py similarity index 84% rename from mttl/models/modifiers/spasity/stk/matrix_ops.py rename to mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py index 6f22cc3d6..db2b0d66e 100644 --- a/mttl/models/modifiers/spasity/stk/matrix_ops.py +++ b/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py @@ -6,6 +6,23 @@ from stk.matrix import Matrix +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + + 
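# [Editor's note, hedged] _dense and _dense_and_sparse (added above and below) build matched
# dense / block-sparse test operands. The stk Matrix returned by _dense_and_sparse exposes the
# fields consumed by the kernels in this series, e.g. for blocking=16: sparse.data is the
# (nnz_blocks, 16, 16) block storage, sparse.row_indices the row block index of each stored
# block, and sparse.column_indices_t / sparse.offsets_t / sparse.block_offsets_t the transposed
# (column-major) block layout; stk.ops.to_dense(sparse) should reproduce the masked dense tensor.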
+def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return ( + dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True), + ) + def _merge_adapters(adapters: List[Matrix]) -> Matrix: """ Merges a list of adapters into a single adapter along the second dimention. diff --git a/mttl/models/modifiers/spasity/stk/__init__.py b/mttl/models/modifiers/spasity/spb_moe/__init__.py similarity index 100% rename from mttl/models/modifiers/spasity/stk/__init__.py rename to mttl/models/modifiers/spasity/spb_moe/__init__.py diff --git a/mttl/models/modifiers/spasity/stk/measure_time.py b/mttl/models/modifiers/spasity/spb_moe/benchmark.py similarity index 94% rename from mttl/models/modifiers/spasity/stk/measure_time.py rename to mttl/models/modifiers/spasity/spb_moe/benchmark.py index c769f68c8..41ae883a1 100644 --- a/mttl/models/modifiers/spasity/stk/measure_time.py +++ b/mttl/models/modifiers/spasity/spb_moe/benchmark.py @@ -9,11 +9,9 @@ from pytorch_lightning import seed_everything from stk.matrix import Matrix -from mttl.models.modifiers.spasity.stk import functions, linear_ops, matrix_ops -from mttl.models.modifiers.spasity.stk.linear_ops_test_scatter import ( - _dense_and_sparse, - dumb_forward, -) +from mttl.models.modifiers.spasity.sparse_utils import stk_matrix_utils as matrix_ops +from mttl.models.modifiers.spasity.spb_moe import linear_ops +from mttl.models.modifiers.spasity.spb_moe.bsr_adapter_moe_test import dumb_forward def benchmark_module(name, function, runs=100): @@ -96,7 +94,7 @@ def find_lora_hyperpaams(d_in, d_out, tot_sparse_params): weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() - adaps = [_dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] + adaps = [matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] adaps_sparse = [adap[1] for adap in adaps] adaps_dense = [adap[0] for adap in adaps] ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) @@ -151,7 +149,7 @@ def call_with_baseact_and_idxs_computation(X, W, expert_idxs, function, **kwargs X=X, W=W, expert_idxs=expert_idxs, - function=linear_ops.scattergather_adamerge2, + function=linear_ops.scattergather_adamerge_opt, k=k, ada_weights=ada_data, row_idxs=row_idxs, diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py b/mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py similarity index 82% rename from mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py rename to mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py index 97246075f..d09725dee 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops_test_scatter.py +++ b/mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py @@ -10,9 +10,10 @@ from pytorch_lightning import seed_everything from stk.matrix import Matrix -from mttl.models.modifiers.spasity.stk import functions, linear_ops, matrix_ops +from mttl.models.modifiers.spasity.sparse_utils import stk_matrix_utils as matrix_ops +from mttl.models.modifiers.spasity.spb_moe import linear_ops -# os.environ["TRITON_INTERPRET"] = "1" +# os.environ["TRITON_INTERPRET"] = "1" # def allclose(x, y, pct=0.25): @@ -26,31 +27,12 @@ 
def allclose(x, y, pct=0.25): blocksize = 16 - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return ( - dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True), - ) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - SC_MOE_TEST = { (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), (8, 128, 256, 10, 2, 0.8, 16, torch.float32), } - def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): output = torch.stack( [ @@ -75,7 +57,7 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() - adaps = [_dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] + adaps = [matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] adaps_sparse = [adap[1] for adap in adaps] adaps_dense = [adap[0] for adap in adaps] ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) @@ -114,7 +96,7 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): # gates=k_weights, # ) - out2 = linear_ops.scattergather_adamerge2( + out2 = linear_ops.scattergather_adamerge_opt( x=X, base_act=base_act, k=k, diff --git a/mttl/models/modifiers/spasity/stk/functions.py b/mttl/models/modifiers/spasity/spb_moe/functions.py similarity index 97% rename from mttl/models/modifiers/spasity/stk/functions.py rename to mttl/models/modifiers/spasity/spb_moe/functions.py index 6e21fd435..5806575ba 100644 --- a/mttl/models/modifiers/spasity/stk/functions.py +++ b/mttl/models/modifiers/spasity/spb_moe/functions.py @@ -4,8 +4,8 @@ from stk.backend.autocast import custom_bwd, custom_fwd from stk.matrix import Matrix -import mttl.models.modifiers.spasity.stk.triton_kernels as backend -from mttl.models.modifiers.spasity.stk.scatter_moe_kernels import ( +import mttl.models.modifiers.spasity.spb_moe._triton_kernels as backend +from mttl.models.modifiers.spasity.spb_moe.triton_kernels import ( scatter2scatter_sparse, scatter2scatter_sparse_optimized, ) diff --git a/mttl/models/modifiers/spasity/stk/linear_ops.py b/mttl/models/modifiers/spasity/spb_moe/linear_ops.py similarity index 98% rename from mttl/models/modifiers/spasity/stk/linear_ops.py rename to mttl/models/modifiers/spasity/spb_moe/linear_ops.py index e68ed333b..536b27932 100644 --- a/mttl/models/modifiers/spasity/stk/linear_ops.py +++ b/mttl/models/modifiers/spasity/spb_moe/linear_ops.py @@ -1,7 +1,7 @@ import torch from stk.matrix import Matrix -from mttl.models.modifiers.spasity.stk import functions +from mttl.models.modifiers.spasity.spb_moe import functions def sdd_adamerge(a, b, out_topo: Matrix, out_adaps: Matrix, layout): @@ -79,7 +79,7 @@ def scattergather_adamerge( return out -def scattergather_adamerge2( +def scattergather_adamerge_opt( x, base_act, k, diff --git a/mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py b/mttl/models/modifiers/spasity/spb_moe/triton_kernels.py similarity index 100% rename from 
mttl/models/modifiers/spasity/stk/scatter_moe_kernels.py rename to mttl/models/modifiers/spasity/spb_moe/triton_kernels.py diff --git a/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py b/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py deleted file mode 100644 index 3679cff12..000000000 --- a/mttl/models/modifiers/spasity/stk/linear_ops_test_megatron.py +++ /dev/null @@ -1,248 +0,0 @@ -import itertools -import os -import unittest - -import numpy as np -import stk -import torch -import torch.nn.functional as F -from absl.testing import parameterized -from pytorch_lightning import seed_everything -from stk.matrix import Matrix - -from mttl.models.modifiers.spasity.stk import functions, linear_ops, matrix_ops - -# os.environ["TRITON_INTERPRET"] = "1" - - -# os.environ["TRITON_INTERPRET"] = "1" - - -# os.environ["TRITON_INTERPRET"] = "1" - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -blocksize = 16 -# An assortment of problems designed to make sure -# the bindings are operating correctly. -_MATRIX_SIZES = ( - (128, 128, 128, 0.8), - (128, 128, 64, 0.8), - (128, 128, 128, 0.0), - # (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - # (512, 128, 128, 0.0), - # (128, 128, 512, 0.0), - # (1024, 512, 512, 0.0), - # (1024, 512, 512, 0.5), - # (1024, 512, 512, 0.75), - # (512, 512, 1024, 0.0), - # (512, 512, 1024, 0.5), - # (512, 512, 1024, 0.75), - # (1024, 1024, 1024, 0.0), - # (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - # (False, True), - # (True, False), - # (True, True), -) - -_DTYPE = (torch.float16, torch.bfloat16) - - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [ - (*size, *trans, blocksize, dtype) for (size, trans, dtype) in testcases - ] - return testcases - - -_LINEAR_OP_TESTS = _generate_testcases() - - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return ( - dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True), - ) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _mmm_with_adapters(a, W_base, topo, adapters): - b = W_base.repeat(1, len(adapters)) - adaps_as_dense = [stk.ops.to_dense(adap) for adap in adapters] - b = b + torch.cat(adaps_as_dense, dim=1) - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - def testLinearOps_Sdd_wAdapters( - self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype - ): - if trans_a or trans_b: - return - # Construct the operands. - # This tests the use-case where we have base weights and a bunch of adapters. We perform SDD of input x with base weights, but block-ssparse adapters are merged into the base weights first. 
- - a_shape = (m, k) - a, acp = _dense_2x(*a_shape, dtype) - - n_adaps = 10 - adapters = [ - _dense_and_sparse(*(k, n), sparsity, blocking, dtype)[1] - for _ in range(n_adaps) - ] - # merge all adapters into a single sparse Matrix() - adaps: Matrix = matrix_ops.merge_adapters(adapters) - - out_shape = (m, n * n_adaps) - _, out_topo = _dense_and_sparse(*out_shape, sparsity, blocking, dtype) - # create a mapping from out_topo to adaps, indicating whether each out_topo bvlock needs to be merged with an adapter block, and if so which one - layout = matrix_ops.create_ada_layout(adaps) - - w_shape = (k, n) - W_base, W_basecp = _dense_2x(*w_shape, dtype) - # Execute the matmul. - out = linear_ops.sdd_adamerge(a, W_base, out_topo, adaps, layout) - expected_out = _mmm_with_adapters(acp, W_basecp, out_topo, adapters) - - adapters_as_dense = torch.cat( - [stk.ops.to_dense(adap) for adap in adapters], dim=1 - ) - adaps_as_dense = stk.ops.to_dense(adaps) - assert ( - torch.sum(adapters_as_dense != adaps_as_dense) == 0 - ), "adapters and adaps should be the same" - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - -SC_MOE_TEST = { - (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), - (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), - (8, 128, 256, 10, 2, 0.8, 16, torch.float32), -} - - -def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): - output = torch.stack( - [ - sum( - base_act[i] - + (expert_p[i, j] * torch.matmul(x[i], (adaps[expert_idxs[i, j]]))) - for j in range(expert_idxs.size(1)) - ) - for i in range(expert_idxs.size(0)) - ], - dim=0, - ) - return output - - -@parameterized.parameters(*SC_MOE_TEST) -class ScatteredMoETest(parameterized.TestCase): - def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): - torch.manual_seed(42) - # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") - logits = torch.randn(bs, E, dtype=dtype) - weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) - X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() - W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() - adaps = [_dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] - adaps_sparse = [adap[1] for adap in adaps] - adaps_dense = [adap[0] for adap in adaps] - ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) - row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) - col_idxs_t = torch.stack( - [adap.column_indices_t for adap in adaps_sparse], dim=0 - ) - offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) - block_offsets_t = torch.stack( - [adap.block_offsets_t for adap in adaps_sparse], dim=0 - ) - - k_weights, expert_idxs = torch.topk(weights, k) - sorted_expert_idxs, sorted_scattered_idxs = linear_ops.flatten_and_sort( - expert_idxs - ) - padded_block_idxs, expert_offsets = linear_ops.padded_block_indices( - sorted_expert_idxs, E - ) - - base_act = torch.matmul(X, W) - - out = functions.parallel_linear( - x=X, - base_act=base_act, - k=k, - ada_weights=ada_data, - row_idxs=row_idxs, - col_idxs=col_idxs_t, - offsets=offsets_t, - block_offsets_t=block_offsets_t, - ada_block_size=blocking, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - 
padded_block_idxs=padded_block_idxs, - gates=k_weights, - ) - - out2 = functions.parallel_linear_optimized( - x=X, - base_act=base_act, - k=k, - ada_weights=ada_data, - row_idxs=row_idxs, - col_idxs=col_idxs_t, - offsets=offsets_t, - block_offsets_t=block_offsets_t, - ada_block_size=blocking, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - padded_block_idxs=padded_block_idxs, - gates=k_weights, - ) - - out_dumb = dumb_forward(base_act, X, k_weights, expert_idxs, adaps_dense) - err_Y = torch.abs(out - out_dumb) - tolerance = 1e-2 - # print(err_Y.max()) - assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() - - -if __name__ == "__main__": - unittest.main() diff --git a/mttl/models/modifiers/spasity/stk/triton_kernels.py b/mttl/models/modifiers/spasity/stk/triton_kernels.py deleted file mode 100644 index 7cba8b178..000000000 --- a/mttl/models/modifiers/spasity/stk/triton_kernels.py +++ /dev/null @@ -1,745 +0,0 @@ -from dataclasses import dataclass - -import torch -import triton -import triton.language as tl - - -@dataclass -class TritonConfig: - BLOCK_M: int = 16 # 128 - BLOCK_N: int = 16 # 128 - BLOCK_K: int = 16 # 32 - # BLOCK_SIZE: int = 128 # block size in the output matrix? - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config( - { - "BLOCK_M": TritonConfig.BLOCK_M, - "BLOCK_N": TritonConfig.BLOCK_N, - "BLOCK_K": TritonConfig.BLOCK_K, - # "BLOCK_SIZE": TritonConfig.BLOCK_SIZE, - }, - num_stages=TritonConfig.NUM_STAGES, - num_warps=TritonConfig.NUM_WARPS, - ), - ], - key=[ - "M", - "N", - "K", - ], # uses these keys to decide wether to re-evaluate the choise of best config -) -@triton.jit # this is understood -def _sdd_adamerge( - A, - B, - S, - OUT, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - row_indices, - column_indices, - layout, - stride_layout_m, - stride_layout_n, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, -): - # matrix multiplication - pid = tl.program_id(0) # in triton only control thread blocks - pid_m = tl.load( - row_indices + pid - ) # row index of the block in the output matrix that is being computed by this thread block - pid_n = tl.load( - column_indices + pid - ) # column index of the block in the output matrix that is being computed by this thread block - rm = pid_m * BLOCK_M + tl.arange( - 0, BLOCK_M - ) # the actual row indices in the output matrix - rn = pid_n * BLOCK_N + tl.arange( - 0, BLOCK_N - ) # the actual column indices in the output matrix - ram = tl.max_contiguous( - tl.multiple_of(rm % M, BLOCK_M), BLOCK_M - ) # optimizes memory throughput by ensuring that the memory accesses are contiguous - rbn = tl.max_contiguous( - tl.multiple_of(rn % N, BLOCK_N), BLOCK_N - ) # optimizes memory throughput by ensuring that the memory accesses are contiguous - rk = tl.arange(0, BLOCK_K) # innialize inner dimention range for the current block - BLOCK_ELEMENTS = BLOCK_M * BLOCK_N # BLOCK_SIZE * BLOCK_SIZE - cm 
= tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - # pointers - A = A + ( - ram[:, None] * stride_am + rk[None, :] * stride_ak - ) # BLOCK_M x BLOCK_K pointes to the dense matrix A for loading - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - s_blck = tl.load(layout + k * stride_layout_m + pid_n * stride_layout_n) - mask = s_blck >= 0 - s_blck = tl.where(mask, s_blck, 0) - s_ptr = ( - S - + s_blck * BLOCK_ELEMENTS - + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - ) - s = tl.load(s_ptr) - s = tl.where(mask[None, None], s, tl.zeros_like(s)) - b = b + s - acc += tl.dot(a, b) # this should be using tensor cores on A100 - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - # Store to sparse matrix - acc = acc.to(OUT.dtype.element_ty) - # remember, in OUT we only store the non-zero elements, so no need to map it to dense matrix - OUT = ( - OUT + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - ) - tl.store(OUT, acc, mask=True) - - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - - -def row_indices(shape, data, offsets, column_indices, out): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows,)](offsets, out) - - -def sdd_spmerge( - x, - base_weights, - shape, - out, - row_indices, - column_indices, - ada_data, - ada_layout, # -): - # E is the number of experts - # ada_data is (E x n_blocks_per_e) x block_size x block_size - # base_weights is dense matrix of shape (K, (expert_out_dim x E) - # ada_row_indices is (E x n_blocks_per_e) - # ada_column_indices is (E x n_blocks_per_e) - # base_weights.shape[1 = expert out dim. - - assert x.shape[1] == base_weights.shape[0], "incompatible dimensions" - M, K = x.shape - _, N = base_weights.shape - assert ( - shape[1] & N == 0 - ), "RHS out dimension must be divisible by base weights output dim." 
- E = shape[1] // N - block_size = ada_data.shape[1] - - _validate_matmul_dims(M, K, N) - - if out.dtype in [torch.float16, torch.bfloat16, torch.float32]: - ACC_TYPE = tl.float32 - else: - raise ValueError(f"Unsupported dtype: {out.dtype}") - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) # this just alunches 61 threadblocks - - stride_am, stride_ak = x.stride(0), x.stride(1) - stride_bk, stride_bn = base_weights.stride(0), base_weights.stride(1) - - _sdd_adamerge[grid]( - x, - base_weights, - ada_data, - out, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - out.stride(1), - out.stride(2), - row_indices, - column_indices, - ada_layout, - ada_layout.stride(0), - ada_layout.stride(1), - ACC_TYPE=ACC_TYPE, - ) - - -# this is from https://github.com/databricks/megablocks/blob/7b0337fa7278d224bf0c9be71c3a92c392fdd340/megablocks/backend/kernels.py#L107 - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f"Expected {ndim}-tensor but got {x.ndim}-tensor") - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f"Expected 1-tensor but got {x.ndim}-tensor") - - -def assert_equal(a, b): - if a != b: - raise ValueError( - f"Expected dimensions to be equal but got {a} and {b}.", - ) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({"BLOCK_X": 64}, num_warps=2), - triton.Config({"BLOCK_X": 128}, num_warps=2), - triton.Config({"BLOCK_X": 256}, num_warps=2), - triton.Config({"BLOCK_X": 128}, num_warps=4), - triton.Config({"BLOCK_X": 256}, num_warps=4), - ], - key=["NUM_COLUMNS"], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. 
- iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. 
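# [Editor's sketch, not part of the patch] _padded_copy_wgrad below computes, for each
# (token, expert) routing pair, the gradient of its routing weight: the dot product between the
# expert-output row (read from the padded buffer) and the incoming gradient row. A dense PyTorch
# reference with illustrative names only:
import torch

def padded_scatter_wgrad_reference(x, grad, indices, bin_ids, bins, padded_bins, top_k):
    out = torch.empty(indices.shape[0], dtype=x.dtype, device=x.device)
    for i in range(indices.shape[0]):
        e = int(bin_ids[i])
        offset_in_bin = i - (int(bins[e - 1]) if e > 0 else 0)
        src = offset_in_bin + (int(padded_bins[e - 1]) if e > 0 else 0)
        idx = int(indices[i])
        out[idx] = (x[src].float() * grad[idx // top_k].float()).sum().to(x.dtype)
    return out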
-@triton.autotune( - configs=[ - triton.Config({"BLOCK_X": 64}, num_warps=2), - triton.Config({"BLOCK_X": 128}, num_warps=2), - triton.Config({"BLOCK_X": 256}, num_warps=2), - triton.Config({"BLOCK_X": 128}, num_warps=4), - triton.Config({"BLOCK_X": 256}, num_warps=4), - ], - key=["NUM_COLUMNS"], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({"BLOCK_X": 64}, num_warps=2), - triton.Config({"BLOCK_X": 128}, num_warps=2), - triton.Config({"BLOCK_X": 256}, num_warps=2), - triton.Config({"BLOCK_X": 128}, num_warps=4), - triton.Config({"BLOCK_X": 256}, num_warps=4), - ], - key=["NUM_COLUMNS"], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. 
- index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros( - (num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device - ) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. - assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
-@triton.autotune( - configs=[ - triton.Config({"BLOCK_X": 64}, num_warps=2), - triton.Config({"BLOCK_X": 128}, num_warps=2), - triton.Config({"BLOCK_X": 256}, num_warps=2), - triton.Config({"BLOCK_X": 128}, num_warps=4), - triton.Config({"BLOCK_X": 256}, num_warps=4), - ], - key=["NUM_COLUMNS"], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out From 5061165b28e3355cf6e089d826a1faa27c2b5b35 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 11:42:32 -0400 Subject: [PATCH 11/24] format --- .../spasity/sparse_utils/bsr_moe_benchmark.py | 8 +- .../spasity/sparse_utils/stk_matrix_utils.py | 2 +- .../modifiers/spasity/sparse_utils/utils.py | 55 ++++++++---- .../modifiers/spasity/spb_moe/benchmark.py | 6 +- .../spasity/spb_moe/bsr_adapter_moe_test.py | 6 +- .../spasity/spb_moe/triton_kernels.py | 83 ++++++++++--------- 6 files changed, 102 insertions(+), 58 deletions(-) diff --git a/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py b/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py index 8bc0417cd..bdb6e003e 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py +++ b/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py @@ -137,8 +137,12 @@ def create_block_diagonal_matrix(bs_m, bs_n, n_blocks): nb_m_pb = bs_m // block_size nb_n_pb = bs_n // block_size - col_indices_1blk = torch.arange(nb_n_pb, device=device, dtype=torch.int32).repeat(nb_m_pb) - row_indices_1blk = torch.arange(nb_m_pb, device=device, dtype=torch.int32).repeat_interleave(nb_n_pb) + col_indices_1blk = torch.arange(nb_n_pb, device=device, dtype=torch.int32).repeat( + nb_m_pb + ) + row_indices_1blk = torch.arange( + nb_m_pb, device=device, dtype=torch.int32 + ).repeat_interleave(nb_n_pb) offsets = torch.arange(0, Mb * nb_n_pb + nb_n_pb, nb_n_pb, device=device) col_idx = torch.cat([col_indices_1blk + i * nb_n_pb for i in range(n_blocks)]) diff --git a/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py b/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py index db2b0d66e..9f7b7e9f4 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py +++ b/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py @@ -12,7 +12,6 @@ def _dense(rows, cols, dtype, std=0.1): return out.to(cuda_device).requires_grad_(True) - def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): mask = stk.random.dense_mask(rows, cols, sparsity, blocking) dense = (torch.randn(rows, cols) * std * mask).type(dtype) @@ -23,6 +22,7 @@ def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): sparse.to(cuda_device).requires_grad_(True), ) + def _merge_adapters(adapters: List[Matrix]) -> Matrix: """ Merges a list of adapters into a single adapter along the second dimention. 
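Note on the next file in this patch: the padded_gather / padded_scatter helpers in sparse_utils/utils.py (reformatted below) group routed tokens by expert and pad each expert's group up to a multiple of the sparse block size before the block-sparse kernels run. The standalone sketch below illustrates only that grouping-and-padding arithmetic; it assumes top_k = 1 and zero-pads instead of duplicating the last token, and toy_padded_gather is an illustrative name, not part of the library API.

import torch

def toy_padded_gather(x, expert_ids, num_experts, block_size=16):
    # x: [tokens, d_model]; expert_ids: [tokens] with values in [0, num_experts)
    order = expert_ids.argsort()                                      # group tokens by expert
    counts = torch.bincount(expert_ids, minlength=num_experts)        # tokens per expert
    padded = ((counts + block_size - 1) // block_size) * block_size   # round up per expert
    capacity = int(padded.max())                                      # shared padded capacity
    out = torch.zeros(num_experts, capacity, x.size(1), dtype=x.dtype)
    pos = torch.cat([torch.arange(int(c)) for c in counts])           # slot of each token inside its expert
    out[expert_ids[order], pos] = x[order]                            # scatter; padding rows stay zero
    return out, order, counts

x = torch.randn(10, 4)
expert_ids = torch.randint(0, 3, (10,))
gathered, order, counts = toy_padded_gather(x, expert_ids, num_experts=3)
# gathered is [num_experts, padded_capacity, d_model]; a matching scatter would use
# order and counts to restore the original token order, which is what padded_scatter does.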
diff --git a/mttl/models/modifiers/spasity/sparse_utils/utils.py b/mttl/models/modifiers/spasity/sparse_utils/utils.py index 52756766c..1de7d7e6b 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/utils.py +++ b/mttl/models/modifiers/spasity/sparse_utils/utils.py @@ -548,8 +548,12 @@ def padded_gather(x, indices, E, block_size=16): indices_flat = indices.view(-1) # [batch_size * seq_len * top_k] # Step 2: Expand x to match indices - x_flat_expanded = x_flat.unsqueeze(1).expand(-1, top_k, -1) # [batch_size * seq_len, top_k, d_model] - x_expert = x_flat_expanded.reshape(-1, d_model) # [batch_size * seq_len * top_k, d_model] + x_flat_expanded = x_flat.unsqueeze(1).expand( + -1, top_k, -1 + ) # [batch_size * seq_len, top_k, d_model] + x_expert = x_flat_expanded.reshape( + -1, d_model + ) # [batch_size * seq_len * top_k, d_model] # Step 3: Sort indices and x_expert to group tokens by expert indices_expert, sort_order = indices_flat.sort() @@ -559,13 +563,17 @@ def padded_gather(x, indices, E, block_size=16): num_tokens_per_expert = torch.bincount(indices_expert, minlength=E) # [E] # Step 5: Compute padded number of tokens per expert - padded_num_tokens_per_expert = ((num_tokens_per_expert + block_size - 1) // block_size) * block_size # [E] + padded_num_tokens_per_expert = ( + (num_tokens_per_expert + block_size - 1) // block_size + ) * block_size # [E] max_tokens_per_expert = padded_num_tokens_per_expert.max().item() # Step 6: Compute positions within each expert def compute_positions_in_group(indices_expert): unique_indices, counts = indices_expert.unique_consecutive(return_counts=True) - positions_in_expert = torch.cat([torch.arange(count, device=indices_expert.device) for count in counts]) + positions_in_expert = torch.cat( + [torch.arange(count, device=indices_expert.device) for count in counts] + ) return positions_in_expert positions_in_expert = compute_positions_in_group(indices_expert) @@ -586,15 +594,17 @@ def compute_positions_in_group(indices_expert): padding = padding_needed[e].item() # Get the indices and positions for the current expert - indices_e = indices_expert[current_idx:current_idx+count] - positions_e = positions_in_expert[current_idx:current_idx+count] - x_expert_e = x_expert_sorted[current_idx:current_idx+count] + indices_e = indices_expert[current_idx : current_idx + count] + positions_e = positions_in_expert[current_idx : current_idx + count] + x_expert_e = x_expert_sorted[current_idx : current_idx + count] # Append original tokens indices_expert_padded.append(indices_e) positions_in_expert_padded.append(positions_e) x_expert_padded.append(x_expert_e) - padding_mask.append(torch.ones(count, dtype=torch.bool, device=indices_expert.device)) + padding_mask.append( + torch.ones(count, dtype=torch.bool, device=indices_expert.device) + ) # If padding is needed, duplicate the last token 'padding' times if padding > 0: @@ -606,7 +616,9 @@ def compute_positions_in_group(indices_expert): indices_expert_padded.append(indices_e_pad) positions_in_expert_padded.append(positions_e_pad) x_expert_padded.append(x_expert_e_pad) - padding_mask.append(torch.zeros(padding, dtype=torch.bool, device=indices_expert.device)) + padding_mask.append( + torch.zeros(padding, dtype=torch.bool, device=indices_expert.device) + ) current_idx += count @@ -623,10 +635,19 @@ def compute_positions_in_group(indices_expert): output[indices_expert_padded, positions_in_expert_padded] = x_expert_padded # Return additional information for padded_scatter - return output, num_tokens_per_expert, sort_order, 
indices_expert_padded, positions_in_expert_padded, padding_mask + return ( + output, + num_tokens_per_expert, + sort_order, + indices_expert_padded, + positions_in_expert_padded, + padding_mask, + ) -def padded_scatter(x, num_tokens_per_expert, sort_order, batch_size, seq_len, top_k, d_model): +def padded_scatter( + x, num_tokens_per_expert, sort_order, batch_size, seq_len, top_k, d_model +): """ Un-permute tokens back to their original positions. @@ -649,8 +670,12 @@ def padded_scatter(x, num_tokens_per_expert, sort_order, batch_size, seq_len, to x_flat = x.view(-1, d_model) # [E * max_tokens_per_expert, d_model] # Step 2: Build indices for valid tokens - expert_indices = torch.repeat_interleave(torch.arange(E, device=device), num_tokens_per_expert) - positions_in_expert = torch.cat([torch.arange(n, device=device) for n in num_tokens_per_expert]) + expert_indices = torch.repeat_interleave( + torch.arange(E, device=device), num_tokens_per_expert + ) + positions_in_expert = torch.cat( + [torch.arange(n, device=device) for n in num_tokens_per_expert] + ) valid_positions = expert_indices * max_tokens_per_expert + positions_in_expert # Step 3: Select valid tokens @@ -660,7 +685,9 @@ def padded_scatter(x, num_tokens_per_expert, sort_order, batch_size, seq_len, to x_expert_sorted = x_valid # Step 5: Reconstruct x_expert using inverse of sort_order - x_expert = torch.empty((batch_size * seq_len * top_k, d_model), device=device, dtype=x.dtype) + x_expert = torch.empty( + (batch_size * seq_len * top_k, d_model), device=device, dtype=x.dtype + ) x_expert[sort_order] = x_expert_sorted # Step 6: Reshape to [batch_size, seq_len, top_k, d_model] diff --git a/mttl/models/modifiers/spasity/spb_moe/benchmark.py b/mttl/models/modifiers/spasity/spb_moe/benchmark.py index 41ae883a1..0c301b8c2 100644 --- a/mttl/models/modifiers/spasity/spb_moe/benchmark.py +++ b/mttl/models/modifiers/spasity/spb_moe/benchmark.py @@ -94,7 +94,9 @@ def find_lora_hyperpaams(d_in, d_out, tot_sparse_params): weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() - adaps = [matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] + adaps = [ + matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E) + ] adaps_sparse = [adap[1] for adap in adaps] adaps_dense = [adap[0] for adap in adaps] ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) @@ -178,4 +180,4 @@ def lora_merge(lora_a, lora_b, x, W_base, W_merge): func_lora = partial( lora_merge, lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights ) - benchmark_module("LoRA merge (our current vanila)", func_lora) + # benchmark_module("LoRA merge (our current vanila)", func_lora) diff --git a/mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py b/mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py index d09725dee..90877c82a 100644 --- a/mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py +++ b/mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py @@ -33,6 +33,7 @@ def allclose(x, y, pct=0.25): (8, 128, 256, 10, 2, 0.8, 16, torch.float32), } + def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): output = torch.stack( [ @@ -57,7 +58,10 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() W = 
torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() - adaps = [matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)] + adaps = [ + matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) + for _ in range(E) + ] adaps_sparse = [adap[1] for adap in adaps] adaps_dense = [adap[0] for adap in adaps] ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) diff --git a/mttl/models/modifiers/spasity/spb_moe/triton_kernels.py b/mttl/models/modifiers/spasity/spb_moe/triton_kernels.py index ff197a659..d757f00a5 100644 --- a/mttl/models/modifiers/spasity/spb_moe/triton_kernels.py +++ b/mttl/models/modifiers/spasity/spb_moe/triton_kernels.py @@ -520,6 +520,7 @@ def _scatter2scatter_sp_optimized( BLOCK_M: tl.constexpr, ACC_TYPE: tl.constexpr, GROUP_M: tl.constexpr, + MAX_K_ITERS: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_n = tl.program_id(axis=1) @@ -556,47 +557,51 @@ def _scatter2scatter_sp_optimized( gate = tl.load(gates + M_idx, mask=E_mask) - if iters > 0: - # pointers to dense matrix - X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) # ...16 - rbk = tl.arange(0, BLOCK_K) # ... 16 - W_blk_ptr = ( - adaW - + (rbk[:, None] * adaW_stride_m) - + (rn[None, :] * adaW_stride_n) - + (E_idx * adaW_stride_e) - ) - BLOCK_SIZE = BLOCK_K * BLOCK_N - ak_block_incr = stride_xk * BLOCK_K - - for K_block_id in range(0, iters): - X = ( - X_blk_ptr - + tl.load( - column_indices_t - + (E_idx * column_indices_t_offset) - + start_inx - + K_block_id - ) - * ak_block_incr + # pointers to dense matrix + X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + W_blk_ptr = ( + adaW + + (rbk[:, None] * adaW_stride_m) + + (rn[None, :] * adaW_stride_n) + + (E_idx * adaW_stride_e) + ) + BLOCK_SIZE = BLOCK_K * BLOCK_N + ak_block_incr = stride_xk * BLOCK_K + + for K_block_id in tl.range(0, MAX_K_ITERS): + valid = K_block_id < iters + X = ( + X_blk_ptr + + tl.load( + column_indices_t + + (E_idx * column_indices_t_offset) + + start_inx + + K_block_id, + mask=valid, + other=0, ) + * ak_block_incr + ) - W = ( - W_blk_ptr - + tl.load( - block_offsets_t - + (E_idx * block_offsets_t_offset) - + start_inx - + K_block_id - ) - * BLOCK_SIZE + W = ( + W_blk_ptr + + tl.load( + block_offsets_t + + (E_idx * block_offsets_t_offset) + + start_inx + + K_block_id, + mask=valid, + other=0, ) + * BLOCK_SIZE + ) - x = tl.load(X, mask=E_mask[:, None]) - w = tl.load(W, mask=N_mask[None, :]) - acc += tl.dot(x, w, out_dtype=ACC_TYPE) + x = tl.load(X, mask=valid & E_mask[:, None], other=0.0) + w = tl.load(W, mask=valid & N_mask[None, :], other=0.0) + acc += tl.dot(x, w, out_dtype=ACC_TYPE) base_act_ptr = ( base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn @@ -607,7 +612,7 @@ def _scatter2scatter_sp_optimized( Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) - # tl.atomic_add(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :], scope="cta") + # tl.atomic_add(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :], scope="cta") # <- could be used to fuse the merging op into the kernel, but it snot working for soem reason def scatter2scatter_sparse_optimized( @@ -660,6 +665,7 @@ def grid(META): M, K = X.size() N = y_dim E = ada_weights.size(0) + MAX_ITERS = (K + ada_block 
- 1) // ada_block with torch.cuda.device(X.device): _scatter2scatter_sp_optimized[grid]( X, @@ -691,5 +697,6 @@ def grid(META): BLOCK_K=ada_block, BLOCK_N=ada_block, ACC_TYPE=tl.float32, + MAX_K_ITERS=MAX_ITERS, ) return O From dd6b71a0754483237cbc50068f50eeb5ae258aa1 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 11:57:31 -0400 Subject: [PATCH 12/24] rename --- mttl/models/modifiers/__init__.py | 2 +- mttl/models/modifiers/sparsity/__init__.py | 1 + .../{spasity => sparsity}/sparse_mask.py | 2 +- .../sparse_utils/bsr_ddsloop_benchmark.py | 2 +- .../sparse_utils/bsr_moe_benchmark.py | 8 ++++---- .../sparse_utils/csr_add_vs_scatter_add.py | 2 +- .../sparse_utils/stk_matrix_test.py | 2 +- .../sparse_utils/stk_matrix_utils.py | 0 .../{spasity => sparsity}/sparse_utils/utils.py | 2 +- .../{spasity => sparsity}/spb_moe/__init__.py | 0 .../{spasity => sparsity}/spb_moe/benchmark.py | 16 ++++++++-------- .../spb_moe/bsr_adapter_moe_test.py | 10 +++++----- .../{spasity => sparsity}/spb_moe/functions.py | 4 ++-- .../linear_ops.py => sparsity/spb_moe/ops.py} | 2 +- .../spb_moe/triton_kernels.py | 0 mttl/models/modifiers/spasity/__init__.py | 1 - tests/test_sparse_masks.py | 2 +- 17 files changed, 28 insertions(+), 28 deletions(-) create mode 100644 mttl/models/modifiers/sparsity/__init__.py rename mttl/models/modifiers/{spasity => sparsity}/sparse_mask.py (99%) rename mttl/models/modifiers/{spasity => sparsity}/sparse_utils/bsr_ddsloop_benchmark.py (99%) rename mttl/models/modifiers/{spasity => sparsity}/sparse_utils/bsr_moe_benchmark.py (96%) rename mttl/models/modifiers/{spasity => sparsity}/sparse_utils/csr_add_vs_scatter_add.py (97%) rename mttl/models/modifiers/{spasity => sparsity}/sparse_utils/stk_matrix_test.py (90%) rename mttl/models/modifiers/{spasity => sparsity}/sparse_utils/stk_matrix_utils.py (100%) rename mttl/models/modifiers/{spasity => sparsity}/sparse_utils/utils.py (99%) rename mttl/models/modifiers/{spasity => sparsity}/spb_moe/__init__.py (100%) rename mttl/models/modifiers/{spasity => sparsity}/spb_moe/benchmark.py (91%) rename mttl/models/modifiers/{spasity => sparsity}/spb_moe/bsr_adapter_moe_test.py (91%) rename mttl/models/modifiers/{spasity => sparsity}/spb_moe/functions.py (97%) rename mttl/models/modifiers/{spasity/spb_moe/linear_ops.py => sparsity/spb_moe/ops.py} (98%) rename mttl/models/modifiers/{spasity => sparsity}/spb_moe/triton_kernels.py (100%) delete mode 100644 mttl/models/modifiers/spasity/__init__.py diff --git a/mttl/models/modifiers/__init__.py b/mttl/models/modifiers/__init__.py index 63de57288..47f54131a 100644 --- a/mttl/models/modifiers/__init__.py +++ b/mttl/models/modifiers/__init__.py @@ -6,4 +6,4 @@ import mttl.models.modifiers.lora # noqa: F401 import mttl.models.modifiers.mlp # noqa: F401 import mttl.models.modifiers.prompt_tuning # noqa: F401 -import mttl.models.modifiers.spasity.sparse_mask # noqa: F401 +import mttl.models.modifiers.sparsity.sparse_mask # noqa: F401 diff --git a/mttl/models/modifiers/sparsity/__init__.py b/mttl/models/modifiers/sparsity/__init__.py new file mode 100644 index 000000000..28883ff10 --- /dev/null +++ b/mttl/models/modifiers/sparsity/__init__.py @@ -0,0 +1 @@ +from mttl.models.modifiers.sparsity.sparse_mask import * diff --git a/mttl/models/modifiers/spasity/sparse_mask.py b/mttl/models/modifiers/sparsity/sparse_mask.py similarity index 99% rename from mttl/models/modifiers/spasity/sparse_mask.py rename to mttl/models/modifiers/sparsity/sparse_mask.py index 53ba15ce5..b5624aabf 100644 --- 
a/mttl/models/modifiers/spasity/sparse_mask.py +++ b/mttl/models/modifiers/sparsity/sparse_mask.py @@ -10,7 +10,7 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier, ModifierConfig, ModifyMixin -from mttl.models.modifiers.spasity.sparse_utils.utils import ( +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( BlcokSparseLinearFunction_SP_ADD, BlcokSparseLinearFunction_SP_SCATTER, LinearWithSparseDelta, diff --git a/mttl/models/modifiers/spasity/sparse_utils/bsr_ddsloop_benchmark.py b/mttl/models/modifiers/sparsity/sparse_utils/bsr_ddsloop_benchmark.py similarity index 99% rename from mttl/models/modifiers/spasity/sparse_utils/bsr_ddsloop_benchmark.py rename to mttl/models/modifiers/sparsity/sparse_utils/bsr_ddsloop_benchmark.py index 1da9d168e..927700d3c 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/bsr_ddsloop_benchmark.py +++ b/mttl/models/modifiers/sparsity/sparse_utils/bsr_ddsloop_benchmark.py @@ -18,7 +18,7 @@ from mttl.models.modifiers import modify_transformer from mttl.models.modifiers.base import Modifier from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig -from mttl.models.modifiers.spasity.sparse_mask import ( +from mttl.models.modifiers.sparsity.sparse_mask import ( MaskedLinear, ScatteredSparseLinearModule, SparseLinearModule, diff --git a/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py b/mttl/models/modifiers/sparsity/sparse_utils/bsr_moe_benchmark.py similarity index 96% rename from mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py rename to mttl/models/modifiers/sparsity/sparse_utils/bsr_moe_benchmark.py index bdb6e003e..d3f209fa7 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/bsr_moe_benchmark.py +++ b/mttl/models/modifiers/sparsity/sparse_utils/bsr_moe_benchmark.py @@ -19,18 +19,18 @@ from mttl.models.modifiers import modify_transformer from mttl.models.modifiers.base import Modifier from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig -from mttl.models.modifiers.spasity.sparse_mask import ( +from mttl.models.modifiers.sparsity.sparse_mask import ( MaskedLinear, ScatteredSparseLinearModule, SparseLinearModule, SparseMaskAdapter, SparseMaskConfig, ) -from mttl.models.modifiers.spasity.sparse_utils.utils import ( +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( padded_gather, padded_scatter, ) -from mttl.models.modifiers.spasity.spb_moe import _matrix_ops, linear_ops +from mttl.models.modifiers.sparsity.spb_moe import _matrix_ops, ops from mttl.models.utils import model_loader_helper, transfer_batch_to_device device = "cuda" @@ -184,7 +184,7 @@ def create_block_diagonal_matrix(bs_m, bs_n, n_blocks): x = x.reshape(-1, in_d).contiguous() out_topology = create_block_diagonal_matrix(out_blck_size, out_d, K) W_base = layer.weight.T.to(dtype=dtype) -output = linear_ops.sdd_adamerge(x, W_base, out_topology, adaptersMatrix, layout) +output = ops.sdd_adamerge(x, W_base, out_topology, adaptersMatrix, layout) print(output.shape) # create output topoly diff --git a/mttl/models/modifiers/spasity/sparse_utils/csr_add_vs_scatter_add.py b/mttl/models/modifiers/sparsity/sparse_utils/csr_add_vs_scatter_add.py similarity index 97% rename from mttl/models/modifiers/spasity/sparse_utils/csr_add_vs_scatter_add.py rename to mttl/models/modifiers/sparsity/sparse_utils/csr_add_vs_scatter_add.py index 1a104404e..93b1c8d07 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/csr_add_vs_scatter_add.py +++ 
b/mttl/models/modifiers/sparsity/sparse_utils/csr_add_vs_scatter_add.py @@ -8,7 +8,7 @@ from triton.ops.blocksparse import matmul from mttl.models.modifiers.sparse_utils.utils import init_sparse_weights -from mttl.models.modifiers.spasity.sparse_mask import SparseMaskConfig, SparseWeights +from mttl.models.modifiers.sparsity.sparse_mask import SparseMaskConfig, SparseWeights n_blocks = 8 BLOCK_SIZE = 128 diff --git a/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_test.py b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py similarity index 90% rename from mttl/models/modifiers/spasity/sparse_utils/stk_matrix_test.py rename to mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py index 750cc761e..3a82258f5 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_test.py +++ b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py @@ -6,7 +6,7 @@ from absl.testing import parameterized from stk.matrix import Matrix -from mttl.models.modifiers.spasity.sparse_utils import stk_matrix_utils as matrix_ops +from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops @parameterized.parameters((2, 8, 16, 0.5, 1), (2, 8, 16, 0.5, 4)) diff --git a/mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_utils.py similarity index 100% rename from mttl/models/modifiers/spasity/sparse_utils/stk_matrix_utils.py rename to mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_utils.py diff --git a/mttl/models/modifiers/spasity/sparse_utils/utils.py b/mttl/models/modifiers/sparsity/sparse_utils/utils.py similarity index 99% rename from mttl/models/modifiers/spasity/sparse_utils/utils.py rename to mttl/models/modifiers/sparsity/sparse_utils/utils.py index 1de7d7e6b..aa34fe800 100644 --- a/mttl/models/modifiers/spasity/sparse_utils/utils.py +++ b/mttl/models/modifiers/sparsity/sparse_utils/utils.py @@ -283,7 +283,7 @@ def init_sparse_weights(sps_type, keep_ratio, shape, block_size=None): def make_sparse_model_during_training(module, batch): - from mttl.models.modifiers.spasity.sparse_mask import ( + from mttl.models.modifiers.sparsity.sparse_mask import ( SparseMaskAdapter as SparseMaskModule, ) diff --git a/mttl/models/modifiers/spasity/spb_moe/__init__.py b/mttl/models/modifiers/sparsity/spb_moe/__init__.py similarity index 100% rename from mttl/models/modifiers/spasity/spb_moe/__init__.py rename to mttl/models/modifiers/sparsity/spb_moe/__init__.py diff --git a/mttl/models/modifiers/spasity/spb_moe/benchmark.py b/mttl/models/modifiers/sparsity/spb_moe/benchmark.py similarity index 91% rename from mttl/models/modifiers/spasity/spb_moe/benchmark.py rename to mttl/models/modifiers/sparsity/spb_moe/benchmark.py index 0c301b8c2..bfe09faf3 100644 --- a/mttl/models/modifiers/spasity/spb_moe/benchmark.py +++ b/mttl/models/modifiers/sparsity/spb_moe/benchmark.py @@ -9,9 +9,9 @@ from pytorch_lightning import seed_everything from stk.matrix import Matrix -from mttl.models.modifiers.spasity.sparse_utils import stk_matrix_utils as matrix_ops -from mttl.models.modifiers.spasity.spb_moe import linear_ops -from mttl.models.modifiers.spasity.spb_moe.bsr_adapter_moe_test import dumb_forward +from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops +from mttl.models.modifiers.sparsity.spb_moe import ops +from mttl.models.modifiers.sparsity.spb_moe.bsr_adapter_moe_test import dumb_forward def benchmark_module(name, function, runs=100): @@ -111,10 +111,10 @@ def 
find_lora_hyperpaams(d_in, d_out, tot_sparse_params): def call_with_baseact_and_idxs_computation(X, W, expert_idxs, function, **kwargs): base_act = torch.matmul(X, W) - sorted_expert_idxs, sorted_scattered_idxs = linear_ops.flatten_and_sort( + sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort( expert_idxs ) - padded_block_idxs, expert_offsets = linear_ops.padded_block_indices( + padded_block_idxs, expert_offsets = ops.padded_block_indices( sorted_expert_idxs, E ) return function( @@ -132,7 +132,7 @@ def call_with_baseact_and_idxs_computation(X, W, expert_idxs, function, **kwargs X=X, W=W, expert_idxs=expert_idxs, - function=linear_ops.scattergather_adamerge, + function=ops.scattergather_adamerge, k=k, ada_weights=ada_data, row_idxs=row_idxs, @@ -151,7 +151,7 @@ def call_with_baseact_and_idxs_computation(X, W, expert_idxs, function, **kwargs X=X, W=W, expert_idxs=expert_idxs, - function=linear_ops.scattergather_adamerge_opt, + function=ops.scattergather_adamerge_opt, k=k, ada_weights=ada_data, row_idxs=row_idxs, @@ -180,4 +180,4 @@ def lora_merge(lora_a, lora_b, x, W_base, W_merge): func_lora = partial( lora_merge, lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights ) - # benchmark_module("LoRA merge (our current vanila)", func_lora) + benchmark_module("LoRA merge (our current vanila)", func_lora) diff --git a/mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py b/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py similarity index 91% rename from mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py rename to mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py index 90877c82a..dcf4c110f 100644 --- a/mttl/models/modifiers/spasity/spb_moe/bsr_adapter_moe_test.py +++ b/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py @@ -10,8 +10,8 @@ from pytorch_lightning import seed_everything from stk.matrix import Matrix -from mttl.models.modifiers.spasity.sparse_utils import stk_matrix_utils as matrix_ops -from mttl.models.modifiers.spasity.spb_moe import linear_ops +from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops +from mttl.models.modifiers.sparsity.spb_moe import ops # os.environ["TRITON_INTERPRET"] = "1" # @@ -75,10 +75,10 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): ) k_weights, expert_idxs = torch.topk(weights, k) - sorted_expert_idxs, sorted_scattered_idxs = linear_ops.flatten_and_sort( + sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort( expert_idxs ) - padded_block_idxs, expert_offsets = linear_ops.padded_block_indices( + padded_block_idxs, expert_offsets = ops.padded_block_indices( sorted_expert_idxs, E ) @@ -100,7 +100,7 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): # gates=k_weights, # ) - out2 = linear_ops.scattergather_adamerge_opt( + out2 = ops.scattergather_adamerge_opt( x=X, base_act=base_act, k=k, diff --git a/mttl/models/modifiers/spasity/spb_moe/functions.py b/mttl/models/modifiers/sparsity/spb_moe/functions.py similarity index 97% rename from mttl/models/modifiers/spasity/spb_moe/functions.py rename to mttl/models/modifiers/sparsity/spb_moe/functions.py index 5806575ba..cd554ff0b 100644 --- a/mttl/models/modifiers/spasity/spb_moe/functions.py +++ b/mttl/models/modifiers/sparsity/spb_moe/functions.py @@ -4,8 +4,8 @@ from stk.backend.autocast import custom_bwd, custom_fwd from stk.matrix import Matrix -import mttl.models.modifiers.spasity.spb_moe._triton_kernels as backend -from 
mttl.models.modifiers.spasity.spb_moe.triton_kernels import ( +import mttl.models.modifiers.sparsity.spb_moe._triton_kernels as backend +from mttl.models.modifiers.sparsity.spb_moe.triton_kernels import ( scatter2scatter_sparse, scatter2scatter_sparse_optimized, ) diff --git a/mttl/models/modifiers/spasity/spb_moe/linear_ops.py b/mttl/models/modifiers/sparsity/spb_moe/ops.py similarity index 98% rename from mttl/models/modifiers/spasity/spb_moe/linear_ops.py rename to mttl/models/modifiers/sparsity/spb_moe/ops.py index 536b27932..de400acba 100644 --- a/mttl/models/modifiers/spasity/spb_moe/linear_ops.py +++ b/mttl/models/modifiers/sparsity/spb_moe/ops.py @@ -1,7 +1,7 @@ import torch from stk.matrix import Matrix -from mttl.models.modifiers.spasity.spb_moe import functions +from mttl.models.modifiers.sparsity.spb_moe import functions def sdd_adamerge(a, b, out_topo: Matrix, out_adaps: Matrix, layout): diff --git a/mttl/models/modifiers/spasity/spb_moe/triton_kernels.py b/mttl/models/modifiers/sparsity/spb_moe/triton_kernels.py similarity index 100% rename from mttl/models/modifiers/spasity/spb_moe/triton_kernels.py rename to mttl/models/modifiers/sparsity/spb_moe/triton_kernels.py diff --git a/mttl/models/modifiers/spasity/__init__.py b/mttl/models/modifiers/spasity/__init__.py deleted file mode 100644 index 49ea5eda6..000000000 --- a/mttl/models/modifiers/spasity/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from mttl.models.modifiers.spasity.sparse_mask import * diff --git a/tests/test_sparse_masks.py b/tests/test_sparse_masks.py index 9fe45c684..fa995ec4f 100644 --- a/tests/test_sparse_masks.py +++ b/tests/test_sparse_masks.py @@ -7,7 +7,7 @@ from pytorch_lightning import seed_everything from mttl.models.modifiers import modify_transformer -from mttl.models.modifiers.spasity.sparse_mask import ( +from mttl.models.modifiers.sparsity.sparse_mask import ( MaskedLinear, ScatteredSparseLinearModule, SNIPMaskUpdateWrapper, From 3ee237d28edac4e49163c391f3800ad56add31ce Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 12:02:15 -0400 Subject: [PATCH 13/24] test --- mttl/models/modifiers/sparsity/sparse_mask.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mttl/models/modifiers/sparsity/sparse_mask.py b/mttl/models/modifiers/sparsity/sparse_mask.py index b5624aabf..02ff7cb50 100644 --- a/mttl/models/modifiers/sparsity/sparse_mask.py +++ b/mttl/models/modifiers/sparsity/sparse_mask.py @@ -38,6 +38,7 @@ class SparseMaskConfig(ModifierConfig): selection_algorithm: str = "rigl" reselection_rate_policy: str = "linear" mask_updater: str = None # "snip" + init_to_zero: bool = True class SparseLinear(ABC): @@ -134,8 +135,8 @@ def __init__(self, config: SparseMaskConfig, shape, dtype, device, **kwargs): self.register_buffer("col_idx", torch.zeros((nnz,), dtype=torch.int16)) self.set_sparse_idxs(_sparse_csr_representation) - - self.init_random() + if not config.init_to_zero: + self.init_random() def init_random(self): # init 1D tensor of values From 1c83f1084c6de2f7520d51380cf5f364a099c5a4 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 12:13:05 -0400 Subject: [PATCH 14/24] test --- tests/test_bsr_moe.py | 88 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tests/test_bsr_moe.py diff --git a/tests/test_bsr_moe.py b/tests/test_bsr_moe.py new file mode 100644 index 000000000..c40f5aedf --- /dev/null +++ b/tests/test_bsr_moe.py @@ -0,0 +1,88 @@ +import pytest +import torch +from pytorch_lightning import 
seed_everything +from stk.matrix import Matrix + +from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops +from mttl.models.modifiers.sparsity.spb_moe import ops + +blocksize = 16 + +SC_MOE_TEST = { + (4, 32, 64, 10, 2, 0.8, 16, torch.float32), + (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), + (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), + (8, 128, 256, 10, 2, 0.8, 16, torch.float32), +} + + +def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): + output = torch.stack( + [ + sum( + base_act[i] + + (expert_p[i, j] * torch.matmul(x[i], (adaps[expert_idxs[i, j]]))) + for j in range(expert_idxs.size(1)) + ) + for i in range(expert_idxs.size(0)) + ], + dim=0, + ) + return output + + +@pytest.mark.skipif( + torch.cuda.is_available() is False, reason="CUDA must be available for this test." +) +@pytest.mark.parametrize("bs, d, h, E, k, sparsity, blocking, dtype", SC_MOE_TEST) +def testScatteredMoE(bs, d, h, E, k, sparsity, blocking, dtype): + seed_everything(42) + device = "cuda" + # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") + logits = torch.randn(bs, E, dtype=dtype) + weights = torch.softmax(logits.float(), axis=-1).to(dtype).to(device) + X = torch.randn(bs, d, dtype=dtype, requires_grad=True).to(device) + W = torch.randn(d, h, dtype=dtype, requires_grad=True).to(device) + adaps = [ + matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E) + ] + adaps_sparse = [adap[1] for adap in adaps] + adaps_dense = [adap[0] for adap in adaps] + ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) + row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) + col_idxs_t = torch.stack([adap.column_indices_t for adap in adaps_sparse], dim=0) + offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) + block_offsets_t = torch.stack( + [adap.block_offsets_t for adap in adaps_sparse], dim=0 + ) + + k_weights, expert_idxs = torch.topk(weights, k) + sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(expert_idxs) + padded_block_idxs, expert_offsets = ops.padded_block_indices(sorted_expert_idxs, E) + + base_act = torch.matmul(X, W) + out2 = ops.scattergather_adamerge_opt( + x=X, + base_act=base_act, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + col_idxs=col_idxs_t, + offsets=offsets_t, + block_offsets_t=block_offsets_t, + ada_block_size=blocking, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + gates=k_weights, + ) + + out_dumb = dumb_forward(base_act, X, k_weights, expert_idxs, adaps_dense) + err_Y = torch.abs(out2 - out_dumb) + tolerance = 1e-2 + print(err_Y.max()) + assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() + + +if __name__ == "__main__": + pytest.main([__file__]) From 1b68a3a19e27d66b7e57efd9177c02e826c64c72 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 12:15:02 -0400 Subject: [PATCH 15/24] black --- mttl/models/modifiers/sparsity/spb_moe/benchmark.py | 4 +--- .../models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mttl/models/modifiers/sparsity/spb_moe/benchmark.py b/mttl/models/modifiers/sparsity/spb_moe/benchmark.py index bfe09faf3..aae3897ac 100644 --- a/mttl/models/modifiers/sparsity/spb_moe/benchmark.py +++ b/mttl/models/modifiers/sparsity/spb_moe/benchmark.py @@ -111,9 +111,7 @@ def 
find_lora_hyperpaams(d_in, d_out, tot_sparse_params): def call_with_baseact_and_idxs_computation(X, W, expert_idxs, function, **kwargs): base_act = torch.matmul(X, W) - sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort( - expert_idxs - ) + sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(expert_idxs) padded_block_idxs, expert_offsets = ops.padded_block_indices( sorted_expert_idxs, E ) diff --git a/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py b/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py index dcf4c110f..9dbe5caaa 100644 --- a/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py +++ b/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py @@ -75,9 +75,7 @@ def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): ) k_weights, expert_idxs = torch.topk(weights, k) - sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort( - expert_idxs - ) + sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(expert_idxs) padded_block_idxs, expert_offsets = ops.padded_block_indices( sorted_expert_idxs, E ) From 306cda630c9a8ec8f1d0cc02b62d9a0e332c4b19 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 13:10:19 -0400 Subject: [PATCH 16/24] tests --- .../sparsity/sparse_utils/stk_matrix_test.py | 34 ++--- .../sparsity/spb_moe/bsr_adapter_moe_test.py | 125 ------------------ 2 files changed, 19 insertions(+), 140 deletions(-) delete mode 100644 mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py diff --git a/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py index 3a82258f5..5552f8d4a 100644 --- a/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py +++ b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py @@ -1,5 +1,4 @@ -import unittest - +import pytest import stk import stk.ops import torch @@ -9,20 +8,25 @@ from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops -@parameterized.parameters((2, 8, 16, 0.5, 1), (2, 8, 16, 0.5, 4)) -class MatrixOpsTest(parameterized.TestCase): - def test_layout_creation(self, K, rows, cols, sparsity, blocking): - adaps = [ - matrix_ops._dense_and_sparse(rows, cols, sparsity, blocking, torch.float16) - for _ in range(K) - ] - adaps_sparse = [adap[1] for adap in adaps] - # adaps_dense = [adap[0] for adap in adaps] +@pytest.mark.parametrize( + "K, rows, cols, sparsity, blocking", + [ + (2, 8, 16, 0.5, 1), + (2, 8, 16, 0.5, 4), + ], +) +def test_layout_creation(self, K, rows, cols, sparsity, blocking): + adaps = [ + matrix_ops._dense_and_sparse(rows, cols, sparsity, blocking, torch.float16) + for _ in range(K) + ] + adaps_sparse = [adap[1] for adap in adaps] + # adaps_dense = [adap[0] for adap in adaps] - merged_adaps_matrix: Matrix = matrix_ops.merge_adapters(adaps_sparse) - layout = matrix_ops.create_ada_layout(merged_adaps_matrix) - assert layout.max() == merged_adaps_matrix.data.size(0) - 1 + merged_adaps_matrix: Matrix = matrix_ops.merge_adapters(adaps_sparse) + layout = matrix_ops.create_ada_layout(merged_adaps_matrix) + assert layout.max() == merged_adaps_matrix.data.size(0) - 1 if __name__ == "__main__": - unittest.main() + pytest.main([__file__]) diff --git a/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py b/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py deleted file mode 100644 index 9dbe5caaa..000000000 --- a/mttl/models/modifiers/sparsity/spb_moe/bsr_adapter_moe_test.py +++ 
/dev/null @@ -1,125 +0,0 @@ -import itertools -import os -import unittest - -import numpy as np -import stk -import torch -import torch.nn.functional as F -from absl.testing import parameterized -from pytorch_lightning import seed_everything -from stk.matrix import Matrix - -from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops -from mttl.models.modifiers.sparsity.spb_moe import ops - -# os.environ["TRITON_INTERPRET"] = "1" # - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -blocksize = 16 - -SC_MOE_TEST = { - (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), - (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), - (8, 128, 256, 10, 2, 0.8, 16, torch.float32), -} - - -def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): - output = torch.stack( - [ - sum( - base_act[i] - + (expert_p[i, j] * torch.matmul(x[i], (adaps[expert_idxs[i, j]]))) - for j in range(expert_idxs.size(1)) - ) - for i in range(expert_idxs.size(0)) - ], - dim=0, - ) - return output - - -@parameterized.parameters(*SC_MOE_TEST) -class ScatteredMoETest(parameterized.TestCase): - def testScatteredMoE(self, bs, d, h, E, k, sparsity, blocking, dtype): - torch.manual_seed(42) - # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") - logits = torch.randn(bs, E, dtype=dtype) - weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) - X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() - W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() - adaps = [ - matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) - for _ in range(E) - ] - adaps_sparse = [adap[1] for adap in adaps] - adaps_dense = [adap[0] for adap in adaps] - ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) - row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) - col_idxs_t = torch.stack( - [adap.column_indices_t for adap in adaps_sparse], dim=0 - ) - offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) - block_offsets_t = torch.stack( - [adap.block_offsets_t for adap in adaps_sparse], dim=0 - ) - - k_weights, expert_idxs = torch.topk(weights, k) - sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(expert_idxs) - padded_block_idxs, expert_offsets = ops.padded_block_indices( - sorted_expert_idxs, E - ) - - base_act = torch.matmul(X, W) - - # out = linear_ops.scattergather_adamerge( - # x=X, - # base_act=base_act, - # k=k, - # ada_weights=ada_data, - # row_idxs=row_idxs, - # col_idxs=col_idxs_t, - # offsets=offsets_t, - # block_offsets_t=block_offsets_t, - # ada_block_size=blocking, - # sorted_expert_idxs=sorted_expert_idxs, - # sorted_scattered_idxs=sorted_scattered_idxs, - # padded_block_idxs=padded_block_idxs, - # gates=k_weights, - # ) - - out2 = ops.scattergather_adamerge_opt( - x=X, - base_act=base_act, - k=k, - ada_weights=ada_data, - row_idxs=row_idxs, - col_idxs=col_idxs_t, - offsets=offsets_t, - block_offsets_t=block_offsets_t, - ada_block_size=blocking, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - padded_block_idxs=padded_block_idxs, - gates=k_weights, - ) - - out_dumb = dumb_forward(base_act, X, k_weights, expert_idxs, adaps_dense) - err_Y = torch.abs(out2 - out_dumb) - tolerance = 1e-2 - 
print(err_Y.max()) - assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max() - - -if __name__ == "__main__": - unittest.main() From 98aabcea65328fbe0cf35aaa5430bf7f5991b7e2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 13:30:02 -0400 Subject: [PATCH 17/24] cleaned --- .../modifiers/sparsity/spb_moe/functions.py | 100 +----------------- 1 file changed, 2 insertions(+), 98 deletions(-) diff --git a/mttl/models/modifiers/sparsity/spb_moe/functions.py b/mttl/models/modifiers/sparsity/spb_moe/functions.py index cd554ff0b..023e5abae 100644 --- a/mttl/models/modifiers/sparsity/spb_moe/functions.py +++ b/mttl/models/modifiers/sparsity/spb_moe/functions.py @@ -4,108 +4,12 @@ from stk.backend.autocast import custom_bwd, custom_fwd from stk.matrix import Matrix -import mttl.models.modifiers.sparsity.spb_moe._triton_kernels as backend from mttl.models.modifiers.sparsity.spb_moe.triton_kernels import ( scatter2scatter_sparse, scatter2scatter_sparse_optimized, ) -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device, - ) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply - - -class SDD_SpMerge(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx, - lhs, - rhs, - shape, - data, - row_indices, - column_indices, - column_indices_t, - block_offsets_t, - adap_data, - ada_maping, - ): - # note for later: here we will need ofdfsets transpose and offsets for the baclkward pass if we implement it - out = torch.empty(data.shape, dtype=lhs.dtype, device=lhs.device) - backend.sdd_spmerge( - lhs, rhs, shape, out, row_indices, column_indices, adap_data, ada_maping - ) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - raise NotImplementedError - - -sdd_spsmerge = SDD_SpMerge.apply - - -class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return backend.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = backend.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply - - class ParalleLinear(torch.autograd.Function): @staticmethod @@ -151,7 +55,7 @@ def forward( parallel_linear = ParalleLinear.apply -class ParalleLinear2(torch.autograd.Function): +class ParalleLinear_optim(torch.autograd.Function): @staticmethod @custom_fwd @@ -193,4 +97,4 @@ def forward( return output -parallel_linear_optimized = ParalleLinear2.apply +parallel_linear_optimized = ParalleLinear_optim.apply From 8fe4540328aca18edd9588a93fa4ae4c4e208808 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 13:56:19 -0400 Subject: [PATCH 18/24] test --- mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py 
b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py index 5552f8d4a..5974bfbe5 100644 --- a/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py +++ b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py @@ -2,7 +2,6 @@ import stk import stk.ops import torch -from absl.testing import parameterized from stk.matrix import Matrix from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops From 8763136adf30a6d469d887a2614113fb224f6001 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 14:26:08 -0400 Subject: [PATCH 19/24] tests --- .../modifiers/sparsity/spb_moe/benchmark.py | 215 ++++++++++-------- tests/test_bsr_moe.py | 16 +- .../test_stk_matrix.py | 2 +- 3 files changed, 119 insertions(+), 114 deletions(-) rename mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py => tests/test_stk_matrix.py (92%) diff --git a/mttl/models/modifiers/sparsity/spb_moe/benchmark.py b/mttl/models/modifiers/sparsity/spb_moe/benchmark.py index aae3897ac..5c57ecc14 100644 --- a/mttl/models/modifiers/sparsity/spb_moe/benchmark.py +++ b/mttl/models/modifiers/sparsity/spb_moe/benchmark.py @@ -4,14 +4,26 @@ import numpy as np import stk import torch -import torch.nn.functional as F -from absl.testing import parameterized from pytorch_lightning import seed_everything from stk.matrix import Matrix from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops from mttl.models.modifiers.sparsity.spb_moe import ops -from mttl.models.modifiers.sparsity.spb_moe.bsr_adapter_moe_test import dumb_forward + + +def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): + output = torch.stack( + [ + sum( + base_act[i] + + (expert_p[i, j] * torch.matmul(x[i], (adaps[expert_idxs[i, j]]))) + for j in range(expert_idxs.size(1)) + ) + for i in range(expert_idxs.size(0)) + ], + dim=0, + ) + return output def benchmark_module(name, function, runs=100): @@ -70,7 +82,7 @@ def find_lora_hyperpaams(d_in, d_out, tot_sparse_params): return int(np.mean(lora_ranks)) -SC_MOE_TEST = { +MOE_TESTCASES = { # bs, d, h, E, k, sparsity, blocking, dtype (1024, 2048, 8192, 20, 2, 0.995, 16, torch.float16), (1024, 2048, 8192, 20, 2, 0.9, 128, torch.float16), @@ -81,101 +93,108 @@ def find_lora_hyperpaams(d_in, d_out, tot_sparse_params): # (8, 128, 256, 10, 2, 0.8, 16, torch.float16), } +if __name__ == "__main__": + for bs, d, h, E, k, sparsity, blocking, dtype in MOE_TESTCASES: + print("=====================================================================") + print( + f"***** Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype} *****" + ) -for bs, d, h, E, k, sparsity, blocking, dtype in SC_MOE_TEST: - print("=====================================================================") - print( - f"***** Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype} *****" - ) - - torch.manual_seed(42) - # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") - logits = torch.randn(bs, E, dtype=dtype) - weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) - X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() - W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() - adaps = [ - matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E) - ] - adaps_sparse = [adap[1] for adap in adaps] - adaps_dense = [adap[0] for adap in adaps] - 
ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) - row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) - col_idxs_t = torch.stack([adap.column_indices_t for adap in adaps_sparse], dim=0) - offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) - block_offsets_t = torch.stack( - [adap.block_offsets_t for adap in adaps_sparse], dim=0 - ) - - k_weights, expert_idxs = torch.topk(weights, k) - - def call_with_baseact_and_idxs_computation(X, W, expert_idxs, function, **kwargs): - base_act = torch.matmul(X, W) - sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(expert_idxs) - padded_block_idxs, expert_offsets = ops.padded_block_indices( - sorted_expert_idxs, E + torch.manual_seed(42) + # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") + logits = torch.randn(bs, E, dtype=dtype) + weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) + X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() + W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() + adaps = [ + matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) + for _ in range(E) + ] + adaps_sparse = [adap[1] for adap in adaps] + adaps_dense = [adap[0] for adap in adaps] + ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) + row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) + col_idxs_t = torch.stack( + [adap.column_indices_t for adap in adaps_sparse], dim=0 ) - return function( - x=X, - base_act=base_act, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - padded_block_idxs=padded_block_idxs, - **kwargs, + offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) + block_offsets_t = torch.stack( + [adap.block_offsets_t for adap in adaps_sparse], dim=0 ) - # base_act = torch.matmul(X, W) - func = partial( - call_with_baseact_and_idxs_computation, - X=X, - W=W, - expert_idxs=expert_idxs, - function=ops.scattergather_adamerge, - k=k, - ada_weights=ada_data, - row_idxs=row_idxs, - col_idxs=col_idxs_t, - offsets=offsets_t, - block_offsets_t=block_offsets_t, - ada_block_size=blocking, - gates=k_weights, - ) - benchmark_module("BS kernel not optimized", func) - # func_dummb = partial(dumb_forward, base_act=base_act, x=X, expert_p=k_weights, expert_idxs=expert_idxs, adaps=adaps_dense) - # benchmark_module("dummy forward", func_dummb) - - func_opt = partial( - call_with_baseact_and_idxs_computation, - X=X, - W=W, - expert_idxs=expert_idxs, - function=ops.scattergather_adamerge_opt, - k=k, - ada_weights=ada_data, - row_idxs=row_idxs, - col_idxs=col_idxs_t, - offsets=offsets_t, - block_offsets_t=block_offsets_t, - ada_block_size=blocking, - gates=k_weights, - ) - benchmark_module("BS kernel optimized", func) - lora_rank = find_lora_hyperpaams(d, h, np.prod(ada_data.shape[1:])) - - def lora_merge(lora_a, lora_b, x, W_base, W_merge): - # LoRA does not profit from lower top-k in this vanila form - # merge into 1 lora - A = torch.einsum("be,edr->bdr", (W_merge, lora_a)) - B = torch.einsum("be,erd->brd", (W_merge, lora_b)) - # lora forward - partial_out = torch.einsum("bd,bdr->br", (x, A)) - adapter_out = torch.einsum("br,brd->bd", (partial_out, B)) - dense_out = x @ W_base - return adapter_out + dense_out - - lora_a = torch.randn(E, d, lora_rank, dtype=dtype).cuda().contiguous() - lora_b = torch.randn(E, lora_rank, h, dtype=dtype).cuda().contiguous() - func_lora = partial( - 
lora_merge, lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights - ) - benchmark_module("LoRA merge (our current vanila)", func_lora) + k_weights, expert_idxs = torch.topk(weights, k) + + def call_with_baseact_and_idxs_computation( + X, W, expert_idxs, function, **kwargs + ): + base_act = torch.matmul(X, W) + sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort( + expert_idxs + ) + padded_block_idxs, expert_offsets = ops.padded_block_indices( + sorted_expert_idxs, E + ) + return function( + x=X, + base_act=base_act, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + **kwargs, + ) + + # base_act = torch.matmul(X, W) + func = partial( + call_with_baseact_and_idxs_computation, + X=X, + W=W, + expert_idxs=expert_idxs, + function=ops.scattergather_adamerge, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + col_idxs=col_idxs_t, + offsets=offsets_t, + block_offsets_t=block_offsets_t, + ada_block_size=blocking, + gates=k_weights, + ) + benchmark_module("BS kernel not optimized", func) + # func_dummb = partial(dumb_forward, base_act=base_act, x=X, expert_p=k_weights, expert_idxs=expert_idxs, adaps=adaps_dense) + # benchmark_module("dummy forward", func_dummb) + + func_opt = partial( + call_with_baseact_and_idxs_computation, + X=X, + W=W, + expert_idxs=expert_idxs, + function=ops.scattergather_adamerge_opt, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + col_idxs=col_idxs_t, + offsets=offsets_t, + block_offsets_t=block_offsets_t, + ada_block_size=blocking, + gates=k_weights, + ) + benchmark_module("BS kernel optimized", func) + lora_rank = find_lora_hyperpaams(d, h, np.prod(ada_data.shape[1:])) + + def lora_merge(lora_a, lora_b, x, W_base, W_merge): + # LoRA does not profit from lower top-k in this vanila form + # merge into 1 lora + A = torch.einsum("be,edr->bdr", (W_merge, lora_a)) + B = torch.einsum("be,erd->brd", (W_merge, lora_b)) + # lora forward + partial_out = torch.einsum("bd,bdr->br", (x, A)) + adapter_out = torch.einsum("br,brd->bd", (partial_out, B)) + dense_out = x @ W_base + return adapter_out + dense_out + + lora_a = torch.randn(E, d, lora_rank, dtype=dtype).cuda().contiguous() + lora_b = torch.randn(E, lora_rank, h, dtype=dtype).cuda().contiguous() + func_lora = partial( + lora_merge, lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights + ) + benchmark_module("LoRA merge (our current vanila)", func_lora) diff --git a/tests/test_bsr_moe.py b/tests/test_bsr_moe.py index c40f5aedf..8266cdc7a 100644 --- a/tests/test_bsr_moe.py +++ b/tests/test_bsr_moe.py @@ -5,6 +5,7 @@ from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops from mttl.models.modifiers.sparsity.spb_moe import ops +from mttl.models.modifiers.sparsity.spb_moe.benchmark import dumb_forward blocksize = 16 @@ -16,21 +17,6 @@ } -def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): - output = torch.stack( - [ - sum( - base_act[i] - + (expert_p[i, j] * torch.matmul(x[i], (adaps[expert_idxs[i, j]]))) - for j in range(expert_idxs.size(1)) - ) - for i in range(expert_idxs.size(0)) - ], - dim=0, - ) - return output - - @pytest.mark.skipif( torch.cuda.is_available() is False, reason="CUDA must be available for this test." 
) diff --git a/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py b/tests/test_stk_matrix.py similarity index 92% rename from mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py rename to tests/test_stk_matrix.py index 5974bfbe5..ad7b8fb85 100644 --- a/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_test.py +++ b/tests/test_stk_matrix.py @@ -14,7 +14,7 @@ (2, 8, 16, 0.5, 4), ], ) -def test_layout_creation(self, K, rows, cols, sparsity, blocking): +def test_layout_creation(K, rows, cols, sparsity, blocking): adaps = [ matrix_ops._dense_and_sparse(rows, cols, sparsity, blocking, torch.float16) for _ in range(K) From ab4802a89de2ffb2aa54a6ee941a4cba566babce Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 14:40:07 -0400 Subject: [PATCH 20/24] do not use sets when parametrizing tests --- tests/test_bsr_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_bsr_moe.py b/tests/test_bsr_moe.py index 8266cdc7a..bf0d46997 100644 --- a/tests/test_bsr_moe.py +++ b/tests/test_bsr_moe.py @@ -9,12 +9,12 @@ blocksize = 16 -SC_MOE_TEST = { +SC_MOE_TEST = [ (4, 32, 64, 10, 2, 0.8, 16, torch.float32), (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32), (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32), (8, 128, 256, 10, 2, 0.8, 16, torch.float32), -} +] @pytest.mark.skipif( From 6bba8af10bb8644e7876f30ad4f5c9bef2757cb9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 4 Oct 2024 15:38:44 -0400 Subject: [PATCH 21/24] nvm --- tests/test_stk_matrix.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_stk_matrix.py b/tests/test_stk_matrix.py index ad7b8fb85..792b071fe 100644 --- a/tests/test_stk_matrix.py +++ b/tests/test_stk_matrix.py @@ -7,6 +7,9 @@ from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops +@pytest.mark.skipif( + torch.cuda.is_available() is False, reason="CUDA must be available for this test." 
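On the "do not use sets when parametrizing tests" change above: a set literal iterates in an arbitrary order that can change between runs and silently deduplicates equal entries, so the generated test IDs and their ordering are not reproducible; a list keeps both fixed. A minimal illustration (test_shapes is hypothetical, not part of the suite):

import pytest

# With a list the cases run in the order written and none are dropped.
@pytest.mark.parametrize("bs, d", [(4, 32), (8, 128)])
def test_shapes(bs, d):
    assert bs > 0 and d > 0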
+) @pytest.mark.parametrize( "K, rows, cols, sparsity, blocking", [ From 4487c0b726cc7b7139a97a94b45fc5039fcc7833 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 6 Nov 2024 10:07:46 -0500 Subject: [PATCH 22/24] cleanup --- mttl/models/modifiers/sm_updater.py | 4 +- mttl/models/modifiers/sparse_mask.py | 8 +- .../{sm_config.py => sparse_mask_config.py} | 2 +- .../models/modifiers/sparsity/mask_updater.py | 156 ++++++++++++++ mttl/models/modifiers/sparsity/sm_updater.py | 197 ++++++++++++++++++ .../{sparse_utils => }/sparse_linear.py | 55 +---- mttl/models/modifiers/sparsity/sparse_mask.py | 4 +- tests/test_sparse_masks.py | 2 +- 8 files changed, 367 insertions(+), 61 deletions(-) rename mttl/models/modifiers/{sm_config.py => sparse_mask_config.py} (88%) create mode 100644 mttl/models/modifiers/sparsity/mask_updater.py create mode 100644 mttl/models/modifiers/sparsity/sm_updater.py rename mttl/models/modifiers/sparsity/{sparse_utils => }/sparse_linear.py (93%) diff --git a/mttl/models/modifiers/sm_updater.py b/mttl/models/modifiers/sm_updater.py index a978a85f3..6389325c5 100644 --- a/mttl/models/modifiers/sm_updater.py +++ b/mttl/models/modifiers/sm_updater.py @@ -11,8 +11,8 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier -from mttl.models.modifiers.sm_config import SparseMaskConfig -from mttl.models.modifiers.sparsity.sparse_utils.sparse_linear import ( +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.sparse_linear import ( MaskedLinear, SparseLinear, ) diff --git a/mttl/models/modifiers/sparse_mask.py b/mttl/models/modifiers/sparse_mask.py index 5bb42fbb2..82c77d8fd 100644 --- a/mttl/models/modifiers/sparse_mask.py +++ b/mttl/models/modifiers/sparse_mask.py @@ -8,9 +8,9 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier, ModifierConfig -from mttl.models.modifiers.sm_config import SparseMaskConfig -from mttl.models.modifiers.sm_updater import MaskUpdater -from mttl.models.modifiers.sparsity.sparse_utils.sparse_linear import ( +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.mask_updater import MaskUpdater +from mttl.models.modifiers.sparsity.sparse_linear import ( MaskedLinear, ScatteredSparseLinearModule, SparseLinear, @@ -46,7 +46,7 @@ def __init__( def forward(self, input): if self.maks_update_mode and self.training: - return self.mask_updater(self.sparse_layer, input) + return self.mask_updater(input) return self.sparse_layer(input) def prepare_for_mask_update(self): diff --git a/mttl/models/modifiers/sm_config.py b/mttl/models/modifiers/sparse_mask_config.py similarity index 88% rename from mttl/models/modifiers/sm_config.py rename to mttl/models/modifiers/sparse_mask_config.py index 4c309f911..4261392fd 100644 --- a/mttl/models/modifiers/sm_config.py +++ b/mttl/models/modifiers/sparse_mask_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from mttl.models.modifiers.sparsity.sparse_utils.sparse_linear import SparseLinearConfig +from mttl.models.modifiers.sparsity.sparse_linear import SparseLinearConfig @dataclass diff --git a/mttl/models/modifiers/sparsity/mask_updater.py b/mttl/models/modifiers/sparsity/mask_updater.py new file mode 100644 index 000000000..9d3880937 --- /dev/null +++ b/mttl/models/modifiers/sparsity/mask_updater.py @@ -0,0 +1,156 @@ +import torch +from scipy.sparse import csr_matrix +from torch import nn + +from mttl.logging import logger +from 
mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.sparse_linear import MaskedLinear, SparseLinear +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( + get_2d_indices_from_csr_matrix, + get_top_k_sparcity, + scipy_csr_to_torch_csr, + torch_csr_to_scipy_csr, +) +from mttl.registrable import Registrable + + +class MaskUpdater(nn.Module, Registrable): + def __init__(self, config: SparseMaskConfig): + super().__init__() + self.config = config + + +@MaskUpdater.register("snip", config_cls=SparseMaskConfig) +class SNIPMaskUpdater(MaskUpdater): + """ + It is used to periodically re-calculate the sparse mask indices a la SNIP (https://arxiv.org/pdf/1810.02340). + To recalculate the mask, it uses ONE infoming batch to estimate the importance of each parameter. + + It accumulates learned weights in a dense CPU matrix. For MaskedLinear implementation this accumulation is already done in MaskedLinear class, since sparse mask is kept in dense format. + This accumulation is useful e.g. to make sure that the weights that have been learned in the past and are selected again are not reinitialized to 0. + """ + + def __init__( + self, config: SparseMaskConfig, base_weights_shape, base_weights_shape_dtype + ): + super().__init__(config) + + self.keep_ratio = config.keep_ratio + self.block_size = config.block_size + + self._n_mask_updates = 0 + self.updating_the_mask = False + + self.binary_mask = None + self._selected_indices = None + self._backward_hooks = [] + self.sparse_layer_weights, self.sparse_layer_biases = None, None + + # sparse weights for accumulation on CPU + self.accumulated_sparse_weights = torch.zeros( + base_weights_shape, device="cpu", dtype=base_weights_shape_dtype + ) + + def switch_to_mask_update_mode(self, sparse_layer: SparseLinear): + self.updating_the_mask = True + self._selected_indices = None + base_weights, base_biases, sparse_weights, sparse_biases = ( + sparse_layer.get_weights_for_mask_learning() + ) + if isinstance(sparse_layer, MaskedLinear): + # here we already keep sparse weights as dense matrix, so accumulation in SNIP is not needed + self.sparse_layer_weights = base_weights + sparse_weights + else: + assert isinstance(sparse_weights, csr_matrix) + # need to do two things: + # 1. keep track of accumulated sparse weights + # 2. 
Merge those accumulated weight deltas into the base weights and use them for importance estimation + r, c = get_2d_indices_from_csr_matrix(sparse_weights) + if len(r) > 0: + self.accumulated_sparse_weights[r, c] = torch.tensor( + sparse_weights[r, c], + dtype=self.accumulated_sparse_weights.dtype, + device="cpu", + ) + self.sparse_layer_weights = ( + base_weights + self.accumulated_sparse_weights.to(base_weights.device) + ) + + self.sparse_layer_biases = base_biases + if sparse_biases is not None: + if self.sparse_layer_biases is None: + self.sparse_layer_biases = sparse_biases.detach() + else: + self.sparse_layer_biases += sparse_biases.detach() + + self.binary_mask = torch.ones_like( + self.sparse_layer_weights, device=self.sparse_layer_weights.device + ) + self.binary_mask.requires_grad = True + + def mask_backward_hook(mask): + selected_params_dense = get_top_k_sparcity( + mask.grad, self.config.sps_type, self.keep_ratio, self.block_size + ) + selected_params = selected_params_dense.float().to_sparse_csr() # .cpu() + if self._selected_indices == None: + self._selected_indices = selected_params # .coalesce() + else: + self._selected_indices += selected_params + self._selected_indices = self._selected_indices # .coalesce() + + mask.grad = None # be efficient, throw aways the grads + return None + + hook_handle = self.binary_mask.register_post_accumulate_grad_hook( + mask_backward_hook + ) + self._backward_hooks.append(hook_handle) + + def switch_to_weights_update_mode(self, sparse_layer: SparseLinear): + self.unregister_hooks() + self.updating_the_mask = False + self.sparse_layer_weights, self.sparse_layer_biases = None, None + # update the mask of the sparse layer + # SNIP weight accumulation: we set the newly selected weights to zeros, + # but weights that have been already learned in the past are kept + if isinstance(sparse_layer, MaskedLinear): + new_weights = self.selected_indices + else: + # other sparse layers than MaskedLinear, do not accumulate weights + # so its handeled here + new_weights = self.selected_indices + new_weights = torch_csr_to_scipy_csr(new_weights) + r, c = get_2d_indices_from_csr_matrix(new_weights) + new_weights *= 0.0 + new_weights[r, c] = self.accumulated_sparse_weights[r, c].float() + new_weights = scipy_csr_to_torch_csr(new_weights) + + sparse_layer.reset_sparse_weights(new_weights) + self._selected_indices = None + self.binary_mask = None + self._n_mask_updates += 1 + + @property + def selected_indices(self) -> torch.Tensor: + if self.config.steps_in_mask_selection == 1: + return self._selected_indices + raise NotImplementedError("More than one step in mask selection is not supported") + + def forward(self, x: torch.Tensor): + input_dtype = x.dtype + x = x.to(self.sparse_layer_weights.dtype) + bias = ( + self.sparse_layer_biases.detach().to(self.sparse_layer_weights.dtype) + if self.sparse_layer_biases is not None + else None + ) + assert self.sparse_layer_weights is not None + return torch.nn.functional.linear( + x, self.sparse_layer_weights.detach() * self.binary_mask, bias + ).to(input_dtype) + + def unregister_hooks(self): + for hook in self._backward_hooks: + hook.remove() + self._backward_hooks = [] diff --git a/mttl/models/modifiers/sparsity/sm_updater.py b/mttl/models/modifiers/sparsity/sm_updater.py new file mode 100644 index 000000000..a238990ec --- /dev/null +++ b/mttl/models/modifiers/sparsity/sm_updater.py @@ -0,0 +1,197 @@ +from abc import ABC, abstractmethod +from collections import namedtuple +from dataclasses import dataclass +from typing 
import Union + +import numpy as np +import torch +from scipy.sparse import csr_matrix +from torch import nn +from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut + +from mttl.logging import logger +from mttl.models.modifiers.base import Modifier +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.sparse_linear import ( + MaskedLinear, + SparseLinear, +) +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( + get_2d_indices_from_csr_matrix, + get_top_k_sparcity, + scipy_csr_to_torch_csr, + torch_csr_to_scipy_csr, +) +from mttl.registrable import Registrable + + +class MaskUpdater(nn.Module, Registrable): + def __init__(self, config: SparseMaskConfig): + super().__init__() + self.config = config + + +@MaskUpdater.register("snip", config_cls=SparseMaskConfig) +class SNIPMaskUpdater(MaskUpdater): + """ + It is used to periodically re-calculate the sparse mask indices a la SNIP (https://arxiv.org/pdf/1810.02340). + To recalculate the mask, it uses a couple of incoming mini-batches to estimate the importance of each parameter. + + It accumulates learned weights in a dense CPU matrix. + This is useful e.g. to make sure that the weights that have been learned in the past and are selected again are not reinitialized to 0. + """ + + def __init__( + self, config: SparseMaskConfig, base_weights_shape, base_weights_shape_dtype + ): + super().__init__(config) + + self.keep_ratio = config.keep_ratio + self.block_size = config.block_size + + self._steps_since_last_mask_update = int(config.skip_zeros_mask_update) + self._mask_update_steps = 0 + self._n_mask_updates = 0 + + self.updating_the_mask = False + + self.binary_mask = None + self._selected_indices = None + self._backward_hooks = [] + self.sparse_layer_weights, self.sparse_layer_biases = None, None + + # sparse weights for accumulation on CPU + self.accumulated_sparse_weights = torch.zeros( + base_weights_shape, device="cpu", dtype=base_weights_shape_dtype + ) + + def switch_to_mask_update_mode(self, sparse_layer): + self.updating_the_mask = True + self._selected_indices = None + base_weights, base_biases, sparse_weights, sparse_biases = ( + sparse_layer.get_weights_for_mask_learning() + ) + if isinstance(sparse_layer, MaskedLinear): + # here we already keep sparse weights as dense matrix, so accumulation in SNIP is not needed + self.sparse_layer_weights = base_weights + sparse_weights + else: + assert isinstance(sparse_weights, csr_matrix) + # need to do two things: + # 1. keep track of accumulated sparse weights + # 2. 
Merge those accumulated weight deltas into the base weights and use them for importance estimation + r, c = get_2d_indices_from_csr_matrix(sparse_weights) + if len(r) > 0: + self.accumulated_sparse_weights[r, c] = torch.tensor( + sparse_weights[r, c], + dtype=self.accumulated_sparse_weights.dtype, + device="cpu", + ) + self.sparse_layer_weights = ( + base_weights + self.accumulated_sparse_weights.to(base_weights.device) + ) + + self.sparse_layer_biases = base_biases + if sparse_biases is not None: + if self.sparse_layer_biases is None: + self.sparse_layer_biases = sparse_biases.detach() + else: + self.sparse_layer_biases += sparse_biases.detach() + + self.binary_mask = torch.ones_like( + self.sparse_layer_weights, device=self.sparse_layer_weights.device + ) + self.binary_mask.requires_grad = True + + def mask_backward_hook(mask): + selected_params_dense = get_top_k_sparcity( + mask.grad, self.config.sps_type, self.keep_ratio, self.block_size + ) + selected_params = selected_params_dense.float().to_sparse_csr() # .cpu() + if self._selected_indices == None: + self._selected_indices = selected_params # .coalesce() + else: + self._selected_indices += selected_params + self._selected_indices = self._selected_indices # .coalesce() + + mask.grad = None # be efficient, throw aways the grads + return None + + hook_handle = self.binary_mask.register_post_accumulate_grad_hook( + mask_backward_hook + ) + self._backward_hooks.append(hook_handle) + + def switch_to_weights_update_mode(self, sparse_layer: SparseLinear): + self.unregister_hooks() + self.updating_the_mask = False + self.sparse_layer_weights, self.sparse_layer_biases = None, None + # update the mask of the sparse layer + # SNIP weight accumulation: we set the newly selected weights to zeros, + # but weights that have been already learned in the past are kept + if isinstance(sparse_layer, MaskedLinear): + new_weights = self.selected_indices + else: + # other sparse layers than MaskedLinear, do not accumulate weights + # so its handeled here + new_weights = self.selected_indices + new_weights = torch_csr_to_scipy_csr(new_weights) + r, c = get_2d_indices_from_csr_matrix(new_weights) + new_weights *= 0.0 + new_weights[r, c] = self.accumulated_sparse_weights[r, c].float() + new_weights = scipy_csr_to_torch_csr(new_weights) + + sparse_layer.reset_sparse_weights(new_weights) + self._selected_indices = None + self.binary_mask = None + self._n_mask_updates += 1 + + @property + def selected_indices(self) -> torch.Tensor: + if self.config.steps_in_mask_selection == 1: + return self._selected_indices + raise NotImplementedError("More than one step in mask selection is not supported") + + def prepare_mask_or_weights_learning(self, sparse_layer: SparseLinear): + """ + Currently we have two regimes that we alternate: + - mask learning: update the non-zero indices + - weight learning: update the sparse weights + + Here we figure out what regume we are in. 
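The alternation described here is driven from the adapter: during mask learning the forward goes through the updater (dense base weights times a trainable binary mask) and the post-accumulate-grad hook converts mask gradients into top-k indices; afterwards the selected entries are written back into the sparse layer. A compressed sketch of that cycle, following mask_updater.py's one-argument forward, with cfg, batches, loss_fn and the already-constructed sparse_layer/updater assumed rather than taken from the patch:

# mask-selection phase: gradients w.r.t. the binary mask act as SNIP scores
updater.switch_to_mask_update_mode(sparse_layer)
for _ in range(cfg.steps_in_mask_selection):
    x, y = next(batches)
    loss = loss_fn(updater(x), y)   # linear(x, W_dense * binary_mask, bias)
    loss.backward()                 # hook turns mask.grad into top-k CSR indices
# write the new mask back (re-using previously accumulated weights), then
# resume normal training of the sparse weights through sparse_layer(x)
updater.switch_to_weights_update_mode(sparse_layer)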
+ """ + if self._time_to_update_mask(sparse_layer) and not self.updating_the_mask: + self.switch_to_mask_update_mode(sparse_layer) + self._mask_update_steps += 1 + + elif self.updating_the_mask and not self._time_to_update_sparse_weights( + sparse_layer + ): + self._mask_update_steps += 1 + + elif self.updating_the_mask and self._time_to_update_sparse_weights( + sparse_layer + ): + self.switch_to_weights_update_mode(sparse_layer) + self._mask_update_steps = 0 + self._steps_since_last_mask_update = 0 + + if not self.updating_the_mask: + self._steps_since_last_mask_update += 1 + + def forward(self, sparse_layer: SparseLinear, x: torch.Tensor): + input_dtype = x.dtype + x = x.to(self.sparse_layer_weights.dtype) + bias = ( + self.sparse_layer_biases.detach().to(self.sparse_layer_weights.dtype) + if self.sparse_layer_biases is not None + else None + ) + assert self.sparse_layer_weights is not None + return torch.nn.functional.linear( + x, self.sparse_layer_weights.detach() * self.binary_mask, bias + ).to(input_dtype) + + def unregister_hooks(self): + for hook in self._backward_hooks: + hook.remove() + self._backward_hooks = [] diff --git a/mttl/models/modifiers/sparsity/sparse_utils/sparse_linear.py b/mttl/models/modifiers/sparsity/sparse_linear.py similarity index 93% rename from mttl/models/modifiers/sparsity/sparse_utils/sparse_linear.py rename to mttl/models/modifiers/sparsity/sparse_linear.py index 717c66877..009f2b39d 100644 --- a/mttl/models/modifiers/sparsity/sparse_utils/sparse_linear.py +++ b/mttl/models/modifiers/sparsity/sparse_linear.py @@ -14,7 +14,6 @@ from mttl.models.modifiers.sparsity.sparse_utils.utils import ( BlcokSparseLinearFunction_SP_ADD, BlcokSparseLinearFunction_SP_SCATTER, - LinearWithSparseDelta, SparseLinearFunction_SP_ADD, _scatter_add_flattened, get_2d_indices_from_csr_matrix, @@ -360,6 +359,10 @@ def scipy_representation(self): data = self.sparse_weights.data[row_idx, col_idx].cpu().float().numpy() return csr_matrix((data, (row_idx, col_idx)), shape=self.base_weight.shape) +############# +# MaskedLinear keeps sparse weights in the dense format. THis has the advantage that we do not neet to fumble with the optimizer. +# Class below try implementing sparse layer in a memory efficient way, similar to to SpIEL (https://arxiv.org/pdf/2401.16405), which uses essentially uses ScatteredSparseLinearModule. +# Using below classes may require additional tricks like in the SpIEL paper. class SparseLinearModule(SparseWeights, SparseLinear): """ @@ -588,53 +591,3 @@ def reset_sparse_weights(self, mask: torch.Tensor): dtype=torch.int64, device=self.base_weight.device, ) - - -class SpieLSparseLinearModule(SparseLinearModule): - """ - This implements the SpIEL kernel: https://arxiv.org/pdf/2401.16405 - """ - - def __init__( - self, - weight, - bias, - config: SparseLinearConfig, - parent_name=None, - mask: torch.Tensor = None, - ): - super().__init__( - weight, - bias, - config, - parent_name, - sparse_func=LinearWithSparseDelta, - ) - indices = torch.tensor( - np.array(self.oneD_indices), - dtype=torch.int64, - device=self.base_weight.device, - ) - self.register_buffer("idxs", indices) - - @property - def oneD_indices(self): - """ - Returns a simple 1d representation of the sparse weights instead of the CSR format. 
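The oneD_indices helper being removed here flattens the CSR coordinates into row * n_cols + col offsets, which is the layout the SpIEL-style LinearWithSparseDelta kernel consumed. A self-contained sketch of that index round trip and of applying a flat-indexed delta to a dense weight (toy shapes, not from the patch):

import torch

rows = torch.tensor([0, 2, 2])
cols = torch.tensor([1, 0, 3])
n_cols = 4
flat = rows * n_cols + cols                      # tensor([1, 8, 11])
assert torch.equal(flat // n_cols, rows) and torch.equal(flat % n_cols, cols)

# scatter a sparse delta stored against the flat indices into a dense base
base = torch.zeros(3, n_cols)
delta = torch.tensor([0.5, -1.0, 2.0])
base.view(-1).scatter_add_(0, flat, delta)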
- """ - twoD_indices = self.twoD_indices - return twoD_indices[0] * self.shape[1] + twoD_indices[1] - - def forward(self, input): - bias = self.base_bias - if bias and self.sparse_bias: - bias = self.base_bias + self.sparse_bias - return self.sparse_func.apply( - input, - self.base_weight, - self.sparse_weights, - self.idxs, - bias, - None, - self.base_weight.dtype, - ) diff --git a/mttl/models/modifiers/sparsity/sparse_mask.py b/mttl/models/modifiers/sparsity/sparse_mask.py index cca4973ce..e6d7d4dfb 100644 --- a/mttl/models/modifiers/sparsity/sparse_mask.py +++ b/mttl/models/modifiers/sparsity/sparse_mask.py @@ -11,8 +11,8 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier, ModifierConfig -from mttl.models.modifiers.sm_config import SparseMaskConfig -from mttl.models.modifiers.sm_updater import MaskUpdater +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.mask_updater import MaskUpdater from mttl.models.modifiers.sparse_utils.sparse_linear import ( MaskedLinear, ScatteredSparseLinearModule, diff --git a/tests/test_sparse_masks.py b/tests/test_sparse_masks.py index ab63939ea..75fc0c866 100644 --- a/tests/test_sparse_masks.py +++ b/tests/test_sparse_masks.py @@ -14,7 +14,7 @@ ScatteredSparseAdapter, ScatteredSparseLinearModule, ) -from mttl.models.modifiers.sparse_utils.sparse_linear import ScatteredSparseLinearModule +from mttl.models.modifiers.sparsity.sparse_linear import ScatteredSparseLinearModule def test_sm_adapter(): From 57792f87f1b61ac79f51eabd61caf992a8a18581 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 6 Nov 2024 10:09:47 -0500 Subject: [PATCH 23/24] black formatter --- mttl/models/modifiers/sparsity/mask_updater.py | 4 +++- mttl/models/modifiers/sparsity/sm_updater.py | 4 +++- mttl/models/modifiers/sparsity/sparse_linear.py | 2 ++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mttl/models/modifiers/sparsity/mask_updater.py b/mttl/models/modifiers/sparsity/mask_updater.py index 9d3880937..0d2908ed2 100644 --- a/mttl/models/modifiers/sparsity/mask_updater.py +++ b/mttl/models/modifiers/sparsity/mask_updater.py @@ -135,7 +135,9 @@ def switch_to_weights_update_mode(self, sparse_layer: SparseLinear): def selected_indices(self) -> torch.Tensor: if self.config.steps_in_mask_selection == 1: return self._selected_indices - raise NotImplementedError("More than one step in mask selection is not supported") + raise NotImplementedError( + "More than one step in mask selection is not supported" + ) def forward(self, x: torch.Tensor): input_dtype = x.dtype diff --git a/mttl/models/modifiers/sparsity/sm_updater.py b/mttl/models/modifiers/sparsity/sm_updater.py index a238990ec..17395221c 100644 --- a/mttl/models/modifiers/sparsity/sm_updater.py +++ b/mttl/models/modifiers/sparsity/sm_updater.py @@ -149,7 +149,9 @@ def switch_to_weights_update_mode(self, sparse_layer: SparseLinear): def selected_indices(self) -> torch.Tensor: if self.config.steps_in_mask_selection == 1: return self._selected_indices - raise NotImplementedError("More than one step in mask selection is not supported") + raise NotImplementedError( + "More than one step in mask selection is not supported" + ) def prepare_mask_or_weights_learning(self, sparse_layer: SparseLinear): """ diff --git a/mttl/models/modifiers/sparsity/sparse_linear.py b/mttl/models/modifiers/sparsity/sparse_linear.py index 009f2b39d..6c74d8d44 100644 --- a/mttl/models/modifiers/sparsity/sparse_linear.py +++ 
b/mttl/models/modifiers/sparsity/sparse_linear.py @@ -359,11 +359,13 @@ def scipy_representation(self): data = self.sparse_weights.data[row_idx, col_idx].cpu().float().numpy() return csr_matrix((data, (row_idx, col_idx)), shape=self.base_weight.shape) + ############# # MaskedLinear keeps sparse weights in the dense format. THis has the advantage that we do not neet to fumble with the optimizer. # Class below try implementing sparse layer in a memory efficient way, similar to to SpIEL (https://arxiv.org/pdf/2401.16405), which uses essentially uses ScatteredSparseLinearModule. # Using below classes may require additional tricks like in the SpIEL paper. + class SparseLinearModule(SparseWeights, SparseLinear): """ Implements a sparse linear layer with sparse weights and sparse backprop. From 016059847ae9953eb78e5d947426ab5eab635138 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 6 Nov 2024 10:14:40 -0500 Subject: [PATCH 24/24] isort formatter --- mttl/models/modifiers/sm_updater.py | 5 +---- mttl/models/modifiers/sparsity/sm_updater.py | 5 +---- mttl/models/modifiers/sparsity/sparse_mask.py | 2 +- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/mttl/models/modifiers/sm_updater.py b/mttl/models/modifiers/sm_updater.py index 6389325c5..ea9ffed25 100644 --- a/mttl/models/modifiers/sm_updater.py +++ b/mttl/models/modifiers/sm_updater.py @@ -12,10 +12,7 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig -from mttl.models.modifiers.sparsity.sparse_linear import ( - MaskedLinear, - SparseLinear, -) +from mttl.models.modifiers.sparsity.sparse_linear import MaskedLinear, SparseLinear from mttl.models.modifiers.sparsity.sparse_utils.utils import ( get_2d_indices_from_csr_matrix, get_top_k_sparcity, diff --git a/mttl/models/modifiers/sparsity/sm_updater.py b/mttl/models/modifiers/sparsity/sm_updater.py index 17395221c..f13bb25a3 100644 --- a/mttl/models/modifiers/sparsity/sm_updater.py +++ b/mttl/models/modifiers/sparsity/sm_updater.py @@ -12,10 +12,7 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig -from mttl.models.modifiers.sparsity.sparse_linear import ( - MaskedLinear, - SparseLinear, -) +from mttl.models.modifiers.sparsity.sparse_linear import MaskedLinear, SparseLinear from mttl.models.modifiers.sparsity.sparse_utils.utils import ( get_2d_indices_from_csr_matrix, get_top_k_sparcity, diff --git a/mttl/models/modifiers/sparsity/sparse_mask.py b/mttl/models/modifiers/sparsity/sparse_mask.py index e6d7d4dfb..a94505583 100644 --- a/mttl/models/modifiers/sparsity/sparse_mask.py +++ b/mttl/models/modifiers/sparsity/sparse_mask.py @@ -12,7 +12,6 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier, ModifierConfig from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig -from mttl.models.modifiers.sparsity.mask_updater import MaskUpdater from mttl.models.modifiers.sparse_utils.sparse_linear import ( MaskedLinear, ScatteredSparseLinearModule, @@ -25,6 +24,7 @@ scipy_csr_to_torch_csr, torch_csr_to_scipy_csr, ) +from mttl.models.modifiers.sparsity.mask_updater import MaskUpdater class SparseMaskAdapter(Modifier):
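As background for the SNIPMaskUpdater introduced in this series: the mask gradients it accumulates are the standard SNIP saliency scores, and get_top_k_sparcity then keeps the highest-scoring (optionally block-structured) entries. A standalone, unstructured sketch of that criterion on a toy layer (keep_ratio and the shapes are illustrative, not values from the patch):

import torch

W = torch.randn(64, 64)                      # frozen base weight
mask = torch.ones_like(W, requires_grad=True)
x, y = torch.randn(8, 64), torch.randn(8, 64)
loss = torch.nn.functional.mse_loss(x @ (W * mask).t(), y)
loss.backward()                              # mask.grad ~ per-weight saliency

keep_ratio = 0.05
scores = mask.grad.abs().flatten()
k = max(1, int(keep_ratio * scores.numel()))
threshold = torch.topk(scores, k).values.min()
binary_mask = (mask.grad.abs() >= threshold).float()   # 1 = keep this weight

In the patch this happens per linear layer via a post-accumulate-grad hook on binary_mask, and the surviving entries are stored as a torch CSR tensor rather than a dense 0/1 matrix.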