diff --git a/CMakeLists.txt b/CMakeLists.txt
index a2dfe18566fa8..546ae466eb852 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -196,6 +196,7 @@ set(CUSTOM_SRC
 "csrc/custom/custom_kernels.cu"
 "csrc/custom/fused_kernels.cu"
 "csrc/custom/custom.cu"
+"csrc/custom/paged_attention/attention_ll4mi.cu"
 )

 define_gpu_extension_target(
diff --git a/ROCm_performance.md b/ROCm_performance.md
index b39f3b42aab76..180c848a21950 100644
--- a/ROCm_performance.md
+++ b/ROCm_performance.md
@@ -12,3 +12,9 @@ The default attention function on ROCm is using triton attention kernel. To fall
 ## Tunable ops
 Pytorch tunable ops are supported.
 Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order to enable both the runtime tuning and the subsequent use of tuned results. To only use the tuned results without tuning any newly encountered shapes, also define `PYTORCH_TUNABLEOP_TUNING=1`
+
+## Custom PagedAttention
+
+On ROCm, a custom paged attention kernel is available for better performance. It is controlled by the env variable `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`, which is currently enabled by default.
+To fall back to the PagedAttention v2 kernel, set the env variable to 0.
+The custom PagedAttention kernel is used for dtype fp16, block-size 16, head-size 128, and max context lengths up to 16k, with a GQA ratio (num_heads//num_kv_heads) between 1 and 16. In all other cases, we fall back to the PagedAttention v2 kernel.
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index f71d1fcaaef50..24f734ce8cce4 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -6,10 +6,11 @@
 import torch

 from vllm._C import ops
+from vllm._custom_C import paged_attention_custom
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

 NUM_BLOCKS = 1024
-PARTITION_SIZE = 512
+PARTITION_SIZE = 256


 @torch.inference_mode()
@@ -77,6 +78,9 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
+        if not args.custom_paged_attn:
+            global PARTITION_SIZE
+            PARTITION_SIZE = 512
         num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                           PARTITION_SIZE)
         tmp_output = torch.empty(
@@ -118,24 +122,43 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     kv_scale,
                 )
             elif version == "v2":
-                ops.paged_attention_v2(
-                    output,
-                    exp_sums,
-                    max_logits,
-                    tmp_output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    context_lens,
-                    block_size,
-                    max_context_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    kv_scale,
-                )
+                if not args.custom_paged_attn:
+                    ops.paged_attention_v2(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        context_lens,
+                        block_size,
+                        max_context_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        kv_scale,
+                    )
+                else:
+                    paged_attention_custom(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        context_lens,
+                        block_size,
+                        max_context_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                    )
             else:
                 raise ValueError(f"Invalid version: {version}")
         torch.cuda.synchronize()
@@ -191,6 +214,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' 'common inference criteria.') + parser.add_argument("--custom-paged-attn", + action="store_true", + help="Use custom paged attention") args = parser.parse_args() print(args) diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index aeff9cc5e6ae7..d75b2d2e41005 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -64,11 +64,36 @@ void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) { at::cuda::getCurrentCUDAStream()); } +void paged_attention_custom( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); + // declare the extension module with the AddGPU function: PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ m.doc() = "pybind11 example plugin"; m.def("LLMM1", &LLMM1); m.def("LLMM_Silu", &LLMM_Silu); m.def("LLZZ", &LLZZ); + m.def( + "paged_attention_custom", + &paged_attention_custom, + "PagedAttention LL4Mi Custom."); //m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu new file mode 100644 index 0000000000000..f94ce77c56a39 --- /dev/null +++ b/csrc/custom/paged_attention/attention_ll4mi.cu @@ -0,0 +1,886 @@ +//TODO: add license terms +#include +#include +#include + +#include + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +#define MAX_PARTITIONS 64 +#define WARP_SIZE 64 + +#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 +#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 + +using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +typedef float16x4 _Half4; +typedef struct _Half8 { _Half4 xy[2]; } _Half8; +////// Non temporal load stores /////// + +#if 1 + +template +__device__ __forceinline__ T load(T* addr) { + return addr[0]; +} + +template +__device__ __forceinline__ void store(T value, T* addr) { + addr[0] = value; +} + +#else + +template +__device__ __forceinline__ T load(const T* addr) { + return __builtin_nontemporal_load(addr); +} + +template <> +__device__ __forceinline__ +float2 load (const float2* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ +float4 load (const float4* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result1 = __builtin_nontemporal_load(addr_alias); + auto result2 = __builtin_nontemporal_load(addr_alias + 1); + float4 ret{}; + auto ret_alias = reinterpret_cast(&result1); + ret.x = ret_alias->x; + ret.y = ret_alias->y; + ret_alias = reinterpret_cast(&result2); + ret.z = ret_alias->x; + ret.w = ret_alias->y; + return ret; +} + +template <> +__device__ __forceinline__ +__half load (const __half* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast<__half *>(&result); + return ret[0]; +} + +template <> 
+__device__ __forceinline__ +__half2 load (const __half2* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast<__half2 *>(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ +vllm::Half4_ load (const vllm::Half4_* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ +vllm::Half8_ load (const vllm::Half8_* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result1 = __builtin_nontemporal_load(addr_alias); + auto result2 = __builtin_nontemporal_load(addr_alias + 1); + vllm::Half8_ ret {}; + auto ret_alias = reinterpret_cast(&result1); + ret.x = ret_alias->x; + ret.y = ret_alias->y; + ret_alias = reinterpret_cast(&result2); + ret.z = ret_alias->x; + ret.w = ret_alias->y; + return ret; +} + +//// Not using nontemporal stores for now +template +__device__ __forceinline__ void store(T value, T* addr) { + return __builtin_nontemporal_store(value, addr); +} + +#endif + +/////////////////////////////////////// + +//grid (num_seqs, num_partitions,num_heads/gqa_ratio) +//block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] +#if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] +#endif + int max_ctx_blocks + ) { + constexpr int NWARPS = NUM_THREADS/WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid%4; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int partition_size = blockDim.x; + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + const int partition_start_token_idx = partition_idx * partition_size; + //exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO,4); // each 4 lanes fetch 4 different qheads, total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4*QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4+1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4+1]; + _Half8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int HELOOP = HEAD_SIZE/x; + _Half8 Klocal[HELOOP]; + constexpr int VHLOOP = HEAD_SIZE/WARP_SIZE; //v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; //16 separate 4xtokens 
across warp -> 16/2 8xtokens + _Half8 Vlocal[VHLOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; + #pragma unroll + for (int h=0; h= context_len) { //warp out of context + #pragma unroll + for(int h=0;h(q_ptr); + const int qhead_elemh8 = laneid/4; + #pragma unroll + for (int h=0; h(k_ptr); + + const int physical_block_offset = local_token_idx%BLOCK_SIZE; //since x=half8, physical_block_offset is already cast as _H8 + + + #pragma unroll + for (int d=0;d(v_ptr); + const int warp_start_block_idx = warp_start_token_idx/BLOCK_SIZE; + //iterate over each v block + #pragma unroll + for (int b=0;b<8*VTLOOP/BLOCK_SIZE;b++) { + const int vblock_idx = warp_start_block_idx + b; + const int vblock_idx_ctx = (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; + const int vphysical_block_number = block_table[vblock_idx_ctx]; + const _Half8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride)/8; + //iterate over each head elem (within head_size) + #pragma unroll + for (int h=0;h>2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h=0;h=4; mask/=2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h],mask)); + } + } + + float exp_sum[QHLOOP]; + #pragma unroll + for (int h=0;h=4; mask/=2) { + exp_sum[h] += __shfl_xor(exp_sum[h],mask); + } + } + + + #pragma unroll + for (int h=0;h every 4 lanes hold 4 heads, each lane holds 4 tokens, there are 4x16 tokens across warp + float16x4 logits[QHLOOP]; + #pragma unroll + for (int h=0;h= context_len) { //warp out of context + #pragma unroll + for (int qh=0; qh partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; + } + #pragma unroll + for (int qh=0; qh +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + //scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + //const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + // + head_idx * max_num_partitions * HEAD_SIZE; + //for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + // out_ptr[i] = tmp_out_ptr[i]; + //} + // Terminate the thread block. + //if num_partitions==1, main kernel will write to out directly, no work in reduction kernel + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. 
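+  // NOTE: with a 256-token partition size and a max context length of 16k, num_partitions
+  // is at most 64 (== WARP_SIZE), so the shuffle reductions below only need to cover a
+  // single warp; the cross-warp reduction path is left commented out.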
+ __shared__ float red_smem[2 * NUM_WARPS]; + __shared__ float shared_global_exp_sum; + //float reg_max_logits[MAX_PARTITIONS]; //dependent on max_num_partitions: assume 32K max context div 1K Partition size -> TODO: make this proper template param + //Assume code below is optimized for MAX_PARTITIONS<=64 TODO: handle larger than warp size cases later + float* shared_max_logits = reinterpret_cast(shared_mem); + float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + //scalar_t tmp_outs[MAX_PARTITIONS]; + + // Load max logits to shared memory. + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + ////float max_logit = -FLT_MAX; + //for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + ////for (int i = threadIdx.x; i < MAX_PARTITIONS; i += blockDim.x) { + //const float l = max_logits_ptr[i]; + //shared_max_logits[i] = l; + ////reg_max_logits[i] = max_logits_ptr[i]; //TODO: review this -> right now num_partitions is very small <=32 + //max_logit = fmaxf(max_logit, l); + ////max_logit = fmaxf(max_logit, reg_max_logits[i]); + ////} + //__syncthreads(); + float max_logit = (threadIdx.x < num_partitions) ? max_logits_ptr[threadIdx.x]:-FLT_MAX; + float reg_max_logit = max_logit; + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } +// if (lane == 0) { +// red_smem[warp_idx] = max_logit; +// } + +// if (num_partitions >= WARP_SIZE) { +// __syncthreads(); +// // Reduce across warps. +// max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +//#pragma unroll +// for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +// max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); +// } +// // Broadcast the max value to all threads. +// //max_logit = __shfl(max_logit, 0); +// } + // Load rescaled exp sums to shared memory. + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + + //for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + // //float l = shared_max_logits[i]; + // //float l = reg_max_logits[i]; + // float rescaled_exp_sum = exp_sums_ptr[i] * expf(reg_max_logits[i] - max_logit); + // global_exp_sum += rescaled_exp_sum; + // shared_exp_sums[i] = rescaled_exp_sum; + //} + float rescaled_exp_sum = (threadIdx.x < num_partitions) ? exp_sums_ptr[threadIdx.x] * expf(reg_max_logit - max_logit) : 0.0f; + global_exp_sum += rescaled_exp_sum; + //if (threadIdx.x < num_partitions) { + //shared_exp_sums[threadIdx.x] = (threadIdx.x < num_partitions) ? rescaled_exp_sum : 0.0f; + shared_exp_sums[threadIdx.x] = rescaled_exp_sum; + //} + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x==0) { + shared_global_exp_sum = global_exp_sum; + } + __syncthreads(); + + //global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
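+  // Each thread of this block owns one head element (the reduce kernel is launched with
+  // HEAD_SIZE threads). It gathers that element from every partition of tmp_out, combines
+  // the partial results using the rescaled exp sums computed above, and scales by the
+  // inverse global exp sum.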
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; +//#pragma unroll + //for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + //if (threadIdx.x < HEAD_SIZE) { //TODO: assume HEAD_SIZE < NUM_THREADS, revisit this assumption later + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; + int lastj=0; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + lastj = (j( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len,\ + alibi_slopes); + +#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ + paged_attention_ll4mi_QKV_kernel \ + <<>>( \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,out_ptr,max_ctx_blocks); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + const int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes) { + + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + //int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + //assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); +#if 0 + T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); + T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); +#endif + //constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + //int logits_size = PARTITION_SIZE * sizeof(float); + //int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + //dim3 grid(num_heads, num_seqs, max_num_partitions); + //int shared_mem_size = std::max(logits_size, outputs_size); + //// For paged attention v2 reduce kernel. 
+ //assert(max_num_partitions <= MAX_PARTITIONS); + //assert(MAX_PARTITIONS<=head_size); + //dim3 reduce_grid(num_heads, num_seqs); + //dim3 reduce_block(head_size); //TODO: assumes max_partitions < head_SIZE + ////dim3 reduce_block(NUM_THREADS); + //int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + //dim3 grid(num_seqs,BLOCK_RATIO_PER_WG*max_ctx_blocks); + //dim3 block(num_heads*HEAD_SIZE*sizeof(T)/sizeof(float4)); + constexpr int NTHR = 256; + const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, NTHR); + //constexpr int NPAR = 2; + //constexpr int GQA_RATIO = 32; + const int gqa_ratio = num_heads/num_kv_heads; + //assert(gqa_ratio>=4); + //assert(gqa_ratio%4==0); + assert(num_heads%num_kv_heads==0); + assert(head_size==HEAD_SIZE); + dim3 grid(num_seqs,max_num_partitions,num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (gqa_ratio) { + case 1: LAUNCH_CUSTOM_ATTENTION(1); break; + case 2: LAUNCH_CUSTOM_ATTENTION(2); break; + case 3: LAUNCH_CUSTOM_ATTENTION(3); break; + case 4: LAUNCH_CUSTOM_ATTENTION(4); break; + case 5: LAUNCH_CUSTOM_ATTENTION(5); break; + case 6: LAUNCH_CUSTOM_ATTENTION(6); break; + case 7: LAUNCH_CUSTOM_ATTENTION(7); break; + case 8: LAUNCH_CUSTOM_ATTENTION(8); break; + case 9: LAUNCH_CUSTOM_ATTENTION(9); break; + case 10: LAUNCH_CUSTOM_ATTENTION(10); break; + case 11: LAUNCH_CUSTOM_ATTENTION(11); break; + case 12: LAUNCH_CUSTOM_ATTENTION(12); break; + case 13: LAUNCH_CUSTOM_ATTENTION(13); break; + case 14: LAUNCH_CUSTOM_ATTENTION(14); break; + case 15: LAUNCH_CUSTOM_ATTENTION(15); break; + case 16: LAUNCH_CUSTOM_ATTENTION(16); break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + //dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + //dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + //constexpr int PARSIZE = 256; + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); //TODO: assumes max_partitions < head_SIZE + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + paged_attention_ll4mi_reduce_kernel + <<>>( + out_ptr, + exp_sums_ptr, + max_logits_ptr, + tmp_out_ptr, + context_lens_ptr, + max_num_partitions); + //switch (head_size) { + // // NOTE(woosuk): To reduce the compilation time, we only compile for the + // // head sizes that we use in the model. However, we can easily extend this + // // to support any head size which is a multiple of 16. 
+ // case 64: + // LAUNCH_PAGED_ATTENTION_V2(64); + // break; + // case 80: + // LAUNCH_PAGED_ATTENTION_V2(80); + // break; + // case 96: + // LAUNCH_PAGED_ATTENTION_V2(96); + // break; + // case 112: + // LAUNCH_PAGED_ATTENTION_V2(112); + // break; + // case 128: + // LAUNCH_PAGED_ATTENTION_V2(128); + // break; + // case 256: + // LAUNCH_PAGED_ATTENTION_V2(256); + // break; + // default: + // TORCH_CHECK(false, "Unsupported head size: ", head_size); + // break; + //} +} + +void paged_attention_custom( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + assert(block_size==16); + if (query.dtype() == at::ScalarType::Half) { + //CALL_V2_LAUNCHER_BLOCK_SIZE(__half); + CALL_CUSTOM_LAUNCHER(_Float16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/tests/kernels/test_attention_custom.py b/tests/kernels/test_attention_custom.py new file mode 100644 index 0000000000000..5bdbf126c22fa --- /dev/null +++ b/tests/kernels/test_attention_custom.py @@ -0,0 +1,292 @@ +import random +from typing import Optional, Tuple + +import pytest +import torch +from allclose_default import get_default_atol, get_default_rtol + +from vllm._C import cache_ops, ops +from vllm._custom_C import paged_attention_custom +from vllm.utils import get_max_shared_memory_bytes, is_hip + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. 
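+# Note: the parameter lists below only cover configurations supported by the custom
+# kernel: fp16 on ROCm, block_size 16, head_size 128, and GQA ratios of 1 to 16.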
+NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 256 +# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} +DTYPES = [torch.half, torch.bfloat16, torch.float + ] if not is_hip() else [torch.half] +NUM_GEN_SEQS = [1, 17, 64] # Arbitrary values for testing +NUM_HEADS = [(8 * x, 8) for x in range(1, 17)] # Arbitrary values for testing + +# FlashAttention forward only supports head dimension at most 128 +# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 +HEAD_SIZES = [128] +BLOCK_SIZES = [16] +USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto"] +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. 
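+            # The bias for position j is (j - context_len + 1) * slope, so the last
+            # token in the context receives zero bias.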
+ position_ids = torch.arange(context_len).int() + alibi_bias = (position_ids - context_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("version", ["custom"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + context_lens[-1] = MAX_SEQ_LEN + #context_lens = [8192 for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int) + #print('>>> ctx lens', context_lens) + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + kv_scale = 1.0 + + # Call the paged attention kernel. 
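+    # The custom kernel follows the v2 partitioning scheme (with PARTITION_SIZE = 256)
+    # but takes no kv_scale argument, and it writes directly to `output` when a
+    # sequence fits in a single partition.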
+ output = torch.empty_like(query) + if version == "v1": + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) + elif version == "v2" or version == "custom": + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // + PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) + elif version == "custom": + paged_attention_custom( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + ) + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + cache_ops.convert_fp8(key_cache, dequantized_key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + cache_ops.convert_fp8(value_cache, dequantized_value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if is_hip() else 1e-3 + rtol = get_default_rtol(output) if is_hip() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-3, 1e-5 + atol = 5e-3 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 256bffdf032eb..72811e1468ab6 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -5,9 +6,16 @@ from vllm._C import cache_ops, ops from vllm.attention.ops.prefix_prefill import context_attention_fwd +from vllm.utils import is_hip + +custom_attn_available = is_hip() and \ + (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0") +if custom_attn_available: + from vllm._custom_C import paged_attention_custom # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
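+# The custom ROCm kernel launches 256 threads per partition, so it uses its own,
+# smaller partition size.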
-_PARTITION_SIZE = 512 +_PARTITION_SIZE_V1V2 = 512 +_PARTITION_SIZE_CUSTOM = 256 @dataclass @@ -100,9 +108,17 @@ def forward_decode( kv_scale: float, ) -> torch.Tensor: output = torch.empty_like(query) - block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape + gqa_ratio = num_heads // num_kv_heads + use_custom = (custom_attn_available and query.dtype == torch.half + and head_size == 128 and block_size == 16 + and kv_cache_dtype == "auto" + and (gqa_ratio >= 1 and gqa_ratio <= 16)) + if not use_custom: + _PARTITION_SIZE = _PARTITION_SIZE_V1V2 + else: + _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use @@ -113,7 +129,8 @@ def forward_decode( # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = (max_context_len <= 8192 - and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + and (max_num_partitions == 1 or num_seqs * num_heads > 512) + and not use_custom) if use_v1: # Run PagedAttention V1. ops.paged_attention_v1( @@ -132,7 +149,7 @@ def forward_decode( kv_scale, ) else: - # Run PagedAttention V2. + # Run PagedAttention V2 or PagedAttention Custom. assert _PARTITION_SIZE % block_size == 0 tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), @@ -145,24 +162,43 @@ def forward_decode( device=output.device, ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - kv_scale, - ) + if not use_custom: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) + else: + paged_attention_custom( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + ) return output @staticmethod