diff --git a/colossalai/inference/flash_decoding_utils.py b/colossalai/inference/flash_decoding_utils.py
index 8f9534d6adf4..48f43bf51622 100644
--- a/colossalai/inference/flash_decoding_utils.py
+++ b/colossalai/inference/flash_decoding_utils.py
@@ -16,6 +16,8 @@ def _reset(self):
         self._tensors_initialized = False
         del self._mid_output
         del self._mid_output_lse
+        del self._exp_sums
+        del self._max_logits
 
     @property
     def is_initialized(self):
@@ -31,6 +33,16 @@ def mid_output_lse(self):
         assert self.is_initialized, "Intermediate tensors not initialized yet"
         return self._mid_output_lse
 
+    @property
+    def exp_sums(self):
+        assert self.is_initialized, "Intermediate tensors not initialized yet"
+        return self._exp_sums
+
+    @property
+    def max_logits(self):
+        assert self.is_initialized, "Intermediate tensors not initialized yet"
+        return self._max_logits
+
     def initialize(
         self,
         max_batch_size: int,
@@ -60,5 +72,11 @@ def initialize(
         self._mid_output_lse = torch.empty(
             size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
         )
+        self._exp_sums = torch.empty(
+            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
+        )
+        self._max_logits = torch.empty(
+            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
+        )
 
         self._tensors_initialized = True
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index e6b39ccfa20d..b50e73d6fcf4 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -338,7 +338,8 @@ def forward(
                 block_size,
                 kv_seq_len,
                 fd_inter_tensor.mid_output,
-                fd_inter_tensor.mid_output_lse,
+                fd_inter_tensor.exp_sums,
+                fd_inter_tensor.max_logits,
                 self.alibi_slopes,
                 sm_scale,
             )
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 5b8b43d4e651..9e54b7e26b09 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -596,7 +596,8 @@ def forward(
                 block_size,
                 kv_seq_len,
                 fd_inter_tensor.mid_output,
-                fd_inter_tensor.mid_output_lse,
+                fd_inter_tensor.exp_sums,
+                fd_inter_tensor.max_logits,
                 None,
                 sm_scale,
             )
diff --git a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py
index d90de6664ed6..da85f4230ac2 100644
--- a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py
+++ b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py
@@ -122,6 +122,8 @@ def benchmark_flash_decoding_attention(
     mid_output_lse = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
     )
+    exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
+    max_logits = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
 
     if provider == "vllm_paged_decoding_attention":
         alibi_slopes = None
@@ -166,7 +168,8 @@ def benchmark_flash_decoding_attention(
         BLOCK_SIZE,
         max_seq_len_across_batch,
         mid_output,
-        mid_output_lse,
+        exp_sums,
+        max_logits,
         alibi_slopes,
         sm_scale,
     )
diff --git a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
index bcea786fe9dd..0845dd5673dd 100644
---
a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu +++ b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu @@ -14,6 +14,7 @@ #include "attention/attention_utils.h" #define WARP_SIZE 32 +#define PARTITION_SIZE 512 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -60,7 +61,7 @@ using namespace colossalAI::cuda::attention; // We only support head size of { 64, 128, 256 } // models like Phi-2, whose head size is 80, is not supported right now template -__global__ void flash_decoding_attention_kernel( +__global__ void flash_decoding_attention_kernel_v1( scalar_t* __restrict__ out, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ q, // [num_tokens, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] @@ -252,25 +253,25 @@ __global__ void flash_decoding_attention_kernel( } } -#define LAUNCH_FLASH_DECODING_ATTENTION_V1(HEAD_SIZE) \ - cudaFuncSetAttribute( \ - ((void*)flash_decoding_attention_kernel), \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - flash_decoding_attention_kernel \ - <<>>( \ - reinterpret_cast(out.data_ptr()), \ - reinterpret_cast(query.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - context_lens.data_ptr(), \ - block_tables.data_ptr(), \ - alibi_slopes_ptr, \ - max_context_len, \ - num_kv_heads, \ - scale, \ - max_num_blocks_per_seq, \ - q_stride, \ - kv_block_stride, \ +#define LAUNCH_FLASH_DECODING_ATTENTION_V1(HEAD_SIZE) \ + cudaFuncSetAttribute( \ + ((void*)flash_decoding_attention_kernel_v1), \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + flash_decoding_attention_kernel_v1 \ + <<>>( \ + reinterpret_cast(out.data_ptr()), \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + context_lens.data_ptr(), \ + block_tables.data_ptr(), \ + alibi_slopes_ptr, \ + max_context_len, \ + num_kv_heads, \ + scale, \ + max_num_blocks_per_seq, \ + q_stride, \ + kv_block_stride, \ kv_head_stride); template< @@ -291,8 +292,10 @@ void flash_decoding_attention_v1_launcher( int num_tokens = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); int q_stride = query.stride(0); + + int max_num_blocks_per_seq = block_tables.size(1); + int num_kv_heads = key_cache.size(1); int kv_block_stride = key_cache.stride(0); int kv_head_stride = key_cache.stride(1); @@ -348,24 +351,477 @@ void flash_decoding_attention_v1_launcher( scale, \ alibi_slopes); + +template +__global__ void flash_decoding_attention_kernel_v2( + scalar_t* __restrict__ out, // [num_tokens, num_heads, max_num_partitions, head_size] + float* __restrict__ exp_sums, // [num_tokens, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_tokens, num_heads, max_num_partitions] + const scalar_t* __restrict__ q, // [num_tokens, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const int* __restrict__ context_lens, // [num_tokens] + const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq] + const float* __restrict__ alibi_slopes, // [num_heads] + const int max_seq_len, + const int num_kv_heads, + 
const float scale, + const int max_num_blocks_per_seq, + const int q_stride, // num_heads * head_size + const int tmp_stride, // num_heads * max_num_partitions + const int kv_block_stride, + const int kv_head_stride) { + const int partition_idx = blockIdx.z; + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int thread_idx = threadIdx.x; + const int lane = thread_idx % WARP_SIZE; + const int warp_idx = thread_idx / WARP_SIZE; + const int max_num_partitions = gridDim.z; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + constexpr int x = sizeof(float4) / sizeof(scalar_t); + constexpr int Q_SHARED_SIZE = HEAD_SIZE / x; + // here thread_group does not determine the number of threads responsible for a key + // but only the VEC_SIZE of each thread + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), x); + constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE; + constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE); + constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN; + constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN; + constexpr int NUM_THREADS_PER_X = x / VEC_SIZE; + constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE); + constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE; + constexpr int NUM_BLOCKS_PER_PARTITION = PARTITION_SIZE / BLOCK_SIZE; + + using KVecT = typename VecTypeTrait::Type; + using VVecT = typename VecTypeTrait::Type; + using KQuantVecT = typename VecTypeTrait::Type; + using VQuantVecT = typename VecTypeTrait::Type; + using LVecT = typename VecTypeTrait::Type; + using FloatVecT = typename FloatVecTypeTrait::Type; + + const int context_len = context_lens[seq_idx]; + + if (partition_idx * PARTITION_SIZE >= context_len) { + return; + } + + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const int thread_group_offset = lane % NUM_THREADS_PER_X; + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = partition_idx * NUM_BLOCKS_PER_PARTITION; + const int end_block_idx = MIN(start_block_idx + NUM_BLOCKS_PER_PARTITION, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. 
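
For orientation, the partition bookkeeping above maps each (sequence, head, partition) thread block onto a contiguous run of KV-cache blocks, PARTITION_SIZE (512) tokens at a time. Below is a minimal Python sketch of the same index arithmetic, assuming the constants defined earlier in this file; the helper name is illustrative and not part of the patch:

```python
# Illustrative sketch only (not part of the patch): the v2 kernel's partition
# index math, with PARTITION_SIZE = 512 as defined at the top of this file.
def partition_ranges(context_len: int, block_size: int, partition_size: int = 512):
    """Yield (start_block, end_block, start_token, end_token) per partition."""
    blocks_per_partition = partition_size // block_size
    num_context_blocks = (context_len + block_size - 1) // block_size
    num_partitions = (context_len + partition_size - 1) // partition_size
    for partition_idx in range(num_partitions):
        start_block = partition_idx * blocks_per_partition
        end_block = min(start_block + blocks_per_partition, num_context_blocks)
        start_token = start_block * block_size
        end_token = min(start_token + (end_block - start_block) * block_size, context_len)
        yield start_block, end_block, start_token, end_token

# e.g. context_len=1300, block_size=16 -> token ranges [0, 512), [512, 1024), [1024, 1300)
```
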
+ const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int shared_memory_offset = DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); + + __shared__ float4 q_shared[Q_SHARED_SIZE]; + __shared__ float red_shared_mem[2 * NUM_WARPS]; + extern __shared__ char shared_mem[]; + int* block_table_shared = reinterpret_cast(shared_mem); + float* logits = reinterpret_cast(shared_mem + shared_memory_offset); + float* out_shared_mem = reinterpret_cast(shared_mem + shared_memory_offset); + float qk_max = -FLT_MAX; + + const float4* q_ptr = reinterpret_cast(q + seq_idx * q_stride + head_idx * HEAD_SIZE); + #pragma unroll + for (int idx = thread_idx; idx < Q_SHARED_SIZE; idx += blockDim.x) { + q_shared[idx] = q_ptr[idx]; + } + + #pragma unroll + for (int idx = thread_idx; idx < max_num_blocks_per_seq; idx += blockDim.x) { + block_table_shared[idx] = block_table[idx]; + } + + __syncthreads(); + + scalar_t* q_shared_ptr = reinterpret_cast(q_shared); + // each warp access a whole block + + KVecT q_vecs[NUM_VECS_PER_THREAD]; + #pragma unroll + for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) { + const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS; + const int offset1 = idx % NUM_THREADS_PER_X; + q_vecs[i] = *reinterpret_cast(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE); + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table_shared[block_idx]); + + KVecT k_vecs[NUM_VECS_PER_THREAD]; + + #pragma unroll + for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) { + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + i * x; + #pragma unroll + for (int idx = lane, j = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, j += 1) { + const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS; + const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS; + const int offset2 = idx % NUM_THREADS_PER_X; + k_vecs[j] = CastFunctor()(*reinterpret_cast(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE)); + } + + float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + + if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) { + const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // there exists a __syncthreads within this function + qk_max = block_max(red_shared_mem, qk_max); + + // Get the sum of the exp values. 
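
The loop that follows computes the per-partition softmax statistics that the kernel later stores in max_logits and exp_sums. Restated in PyTorch for a single (sequence, head, partition), as a reference sketch rather than the CUDA implementation:

```python
# Reference only: per-partition softmax statistics for one (sequence, head,
# partition). qk_max is what ends up in max_logits, exp_sum in exp_sums;
# the logits are normalized within the partition.
import torch

def partition_softmax_stats(scores: torch.Tensor):
    # scores: scaled q.k logits for the tokens of one partition
    qk_max = scores.max()
    probs = torch.exp(scores - qk_max)   # numerically stable exponentials
    exp_sum = probs.sum()
    probs = probs / (exp_sum + 1e-6)     # matches the kernel's inv_sum scaling
    return probs, qk_max, exp_sum
```
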
+ float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + + exp_sum = block_sum(&red_shared_mem[NUM_WARPS], exp_sum); + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + if (thread_idx == 0) { + float* max_logits_ptr = max_logits + seq_idx * tmp_stride + + head_idx * max_num_partitions + + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * tmp_stride + + head_idx * max_num_partitions + + partition_idx; + *max_logits_ptr = qk_max; + *exp_sums_ptr = exp_sum; + } + + FloatVecT accs[NUM_ROUNDS_PER_TOKEN]; + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + zero(accs[i]); + } + + VVecT zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table_shared[block_idx]); + scalar_t logit; + + #pragma unroll + for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) { + const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + idx * VEC_SIZE; + + VVecT v_vecs[NUM_ROUNDS_PER_TOKEN]; + + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + v_vecs[i] = CastFunctor()(*((reinterpret_cast(v_ptr) + i * WARP_SIZE))); + } + + if (token_idx >= context_len) { + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + v_vecs[i] = zero_value; + } + } + + logit = CastFunctor()(logits[token_idx - start_token_idx]); + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + accs[i] = TernaryOpFunctor()(logit, v_vecs[i], accs[i]); + } + } + } + + // must insert a sync since both logits and out_shared_mem occupy the same buffer space + __syncthreads(); + + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + block_sum(out_shared_mem, accs[i]); + } + + scalar_t* out_ptr = out + seq_idx * q_stride * max_num_partitions + + head_idx * HEAD_SIZE * max_num_partitions + + partition_idx * HEAD_SIZE; + LVecT out_reg; + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + if (thread_idx < NUM_THREADS_PER_TOKEN) { + out_reg = CastFunctor()(accs[i]); + (reinterpret_cast(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg; + } + } +} + +template +__global__ void flash_decoding_reduce_kernel( + scalar_t* __restrict__ out, // [num_tokens, num_heads, head_size] + float* __restrict__ exp_sums, // [num_tokens, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_tokens, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_tokens] + const int out_stride, + const int tmp_stride, + const int max_num_partitions) { + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + + extern __shared__ char shared_mem[]; + __shared__ float red_smem[2 * NUM_WARPS]; + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + seq_idx * tmp_stride + + head_idx * 
max_num_partitions; + + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float tmp_max_logit = max_logits_ptr[i]; + shared_max_logits[i] = tmp_max_logit; + max_logit = fmaxf(max_logit, tmp_max_logit); + } + + __syncthreads(); + + max_logit = block_max(red_smem, max_logit); + + float* shared_exp_sums = reinterpret_cast(shared_mem + num_partitions * sizeof(float)); + const float* exp_sums_ptr = exp_sums + seq_idx * tmp_stride + + head_idx * max_num_partitions; + + float global_exp_sum = 0.f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float tmp_max_logit = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(tmp_max_logit - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + + __syncthreads(); + + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.f, global_exp_sum + 1e-6f); + + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * out_stride * max_num_partitions + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * out_stride + head_idx * HEAD_SIZE; + + #pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.f; + for (int j = 0; j < num_partitions; j++) { + acc += CastFunctor()(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + } + out_ptr[i] = CastFunctor()(acc); + } +} + + +#define LAUNCH_FLASH_DECODING_ATTENTION_V2(HEAD_SIZE) \ + cudaFuncSetAttribute( \ + ((void*)flash_decoding_attention_kernel_v2), \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + flash_decoding_attention_kernel_v2 \ + <<>>( \ + reinterpret_cast(tmp_out.data_ptr()), \ + reinterpret_cast(exp_sums.data_ptr()), \ + reinterpret_cast(max_logits.data_ptr()), \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + reinterpret_cast(context_lens.data_ptr()), \ + reinterpret_cast(block_tables.data_ptr()), \ + alibi_slopes_ptr, \ + max_context_len, \ + num_kv_heads, \ + scale, \ + max_num_blocks_per_seq, \ + q_stride, \ + tmp_stride, \ + kv_block_stride, \ + kv_head_stride); \ + cudaFuncSetAttribute( \ + ((void*)flash_decoding_reduce_kernel), \ + cudaFuncAttributeMaxDynamicSharedMemorySize, reduce_shared_mem_size); \ + flash_decoding_reduce_kernel \ + <<>>( \ + reinterpret_cast(out.data_ptr()), \ + reinterpret_cast(exp_sums.data_ptr()), \ + reinterpret_cast(max_logits.data_ptr()), \ + reinterpret_cast(tmp_out.data_ptr()), \ + reinterpret_cast(context_lens.data_ptr()), \ + q_stride, \ + tmp_stride, \ + max_num_partitions); + + +template< + typename T, + typename CACHE_T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void flash_decoding_attention_v2_launcher( + torch::Tensor& out, // [num_tokens, num_heads, head_size] + torch::Tensor& exp_sums, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& context_lens, // [num_tokens] + torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] + int max_context_len, + float scale, + const 
c10::optional& alibi_slopes) { + int num_tokens = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int q_stride = query.stride(0); + int tmp_stride = exp_sums.stride(0); + + int max_num_blocks_per_seq = block_tables.size(1); + + int num_kv_heads = key_cache.size(1); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + const int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((head_size / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(T)); + const int NUM_VECS_PER_TOKEN = head_size / VEC_SIZE; + const int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE); + + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float); + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); + + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + dim3 grid(num_heads, num_tokens, max_num_partitions); + dim3 block(NUM_THREADS); + + dim3 reduce_grid(num_heads, num_tokens); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. + case 64: + LAUNCH_FLASH_DECODING_ATTENTION_V2(64); + break; + case 128: + LAUNCH_FLASH_DECODING_ATTENTION_V2(128); + break; + case 256: + LAUNCH_FLASH_DECODING_ATTENTION_V2(256); + break; + default: + AT_ERROR("head size must be 64, 128, 256"); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE) \ + flash_decoding_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + context_lens, \ + block_tables, \ + max_context_len, \ + scale, \ + alibi_slopes); + // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
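
flash_decoding_reduce_kernel then stitches the partitions back together: each partition's partial output is rescaled by its exp_sum and by the gap between its max_logit and the global maximum. A PyTorch sketch of that reduction for one (sequence, head), assuming tmp_out rows were already normalized by their own partition's exp_sum as in the v2 kernel above:

```python
# Sketch of the reduction performed by flash_decoding_reduce_kernel for one
# (sequence, head); not the CUDA code itself.
import torch

def reduce_partitions(tmp_out: torch.Tensor, max_logits: torch.Tensor, exp_sums: torch.Tensor):
    # tmp_out: [num_partitions, head_size]; max_logits, exp_sums: [num_partitions]
    global_max = max_logits.max()
    rescaled = exp_sums * torch.exp(max_logits - global_max)  # per-partition rescaling
    weights = rescaled / (rescaled.sum() + 1e-6)              # partition weights
    return (weights.unsqueeze(-1) * tmp_out).sum(dim=0)       # [head_size]
```
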
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T) \ +#define CALL_LAUNCHER_BLOCK_SIZE(Version, T, CACHE_T) \ switch (block_size) { \ case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8); \ + CALL_##Version##_LAUNCHER(T, CACHE_T, 8); \ break; \ case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16); \ + CALL_##Version##_LAUNCHER(T, CACHE_T, 16); \ break; \ case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32); \ + CALL_##Version##_LAUNCHER(T, CACHE_T, 32); \ break; \ default: \ AT_ERROR("block size must be 8, 16, 32"); \ break; \ } +#define CALL_LAUNCHER_DTYPE(Version) \ + if(key_cache.scalar_type() == at::ScalarType::Byte) \ + { \ + switch (query.scalar_type()) { \ + case at::ScalarType::Float: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, float, uint8_t); \ + break; \ + case at::ScalarType::Half: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, half, uint8_t); \ + break; \ + case at::ScalarType::BFloat16: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, __nv_bfloat16, uint8_t); \ + break; \ + } \ + } \ + else \ + { \ + switch (query.scalar_type()) { \ + case at::ScalarType::Float: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, float, float); \ + break; \ + case at::ScalarType::Half: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, half, half); \ + break; \ + case at::ScalarType::BFloat16: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, __nv_bfloat16, __nv_bfloat16); \ + break; \ + } \ + } + void flash_decoding_attention( torch::Tensor& out, // [num_tokens, num_heads, head_size] torch::Tensor& query, // [num_tokens, num_heads, head_size] @@ -376,41 +832,27 @@ void flash_decoding_attention( int block_size, int max_context_len, torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] - torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& exp_sums, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_tokens, num_heads, max_num_partitions] const c10::optional& alibi_slopes, float scale) { - if(key_cache.scalar_type() == at::ScalarType::Byte) - { - switch (query.scalar_type()) { - case at::ScalarType::Float: - CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t); - break; - case at::ScalarType::Half: - CALL_V1_LAUNCHER_BLOCK_SIZE(half, uint8_t); - break; - case at::ScalarType::BFloat16: - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t); - break; - } - } - else - { - switch (query.scalar_type()) { - case at::ScalarType::Float: - CALL_V1_LAUNCHER_BLOCK_SIZE(float, float); - break; - case at::ScalarType::Half: - CALL_V1_LAUNCHER_BLOCK_SIZE(half, half); - break; - case at::ScalarType::BFloat16: - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16); - break; - } + int num_tokens = query.size(0); + int num_heads = query.size(1); + + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + // TODO(luoxiang): Need to be tuned + bool use_v1 = max_context_len <= 8192 && (max_num_partitions == 1 || num_tokens * num_heads > 512); + + if (use_v1) { + CALL_LAUNCHER_DTYPE(V1); + } else { + CALL_LAUNCHER_DTYPE(V2); } } #undef LAUNCH_FLASH_DECODING_ATTENTION_V1 -#undef CALL_V1_LAUNCHER -#undef CALL_V1_LAUNCHER_BLOCK_SIZE +#undef CALL_LAUNCHER +#undef CALL_LAUNCHER_BLOCK_SIZE +#undef CALL_LAUNCHER_DTYPE diff --git a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu index 68b47c7e9f18..4f96c7c42c1f 100644 --- a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu +++ b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -24,6 +24,8 @@ __device__ void apply_emb_rotary_compute( 
BinaryOpFunctor mul; BinaryOpFunctor sub; BinaryOpFunctor add; + CastFunctor t2mt; + CastFunctor mt2t; T x[VecSize]; T y[VecSize]; @@ -44,10 +46,10 @@ __device__ void apply_emb_rotary_compute( #pragma unroll for (int j = 0; j < VecSize; j++) { - out_x[j] = CastFunctor()(sub(mul(CastFunctor()(x[j]), cos_ptr[j * 32 + shard_offset]), - mul(CastFunctor()(y[j]), sin_ptr[j * 32 + shard_offset]))); - out_y[j] = CastFunctor()(add(mul(CastFunctor()(y[j]), cos_ptr[j * 32 + shard_offset]), - mul(CastFunctor()(x[j]), sin_ptr[j * 32 + shard_offset]))); + out_x[j] = mt2t(sub(mul(t2mt(x[j]), cos_ptr[j * 32 + shard_offset]), + mul(t2mt(y[j]), sin_ptr[j * 32 + shard_offset]))); + out_y[j] = mt2t(add(mul(t2mt(y[j]), cos_ptr[j * 32 + shard_offset]), + mul(t2mt(x[j]), sin_ptr[j * 32 + shard_offset]))); } copy(out_x, src + addr_offset); diff --git a/extensions/pybind/inference/inference.cpp b/extensions/pybind/inference/inference.cpp index e0fac00bd28d..b9467391188a 100644 --- a/extensions/pybind/inference/inference.cpp +++ b/extensions/pybind/inference/inference.cpp @@ -72,7 +72,8 @@ void flash_decoding_attention( int block_size, int max_context_len, torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] - torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& exp_sums, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_tokens, num_heads, max_num_partitions] const c10::optional& alibi_slopes, float scale); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py index 80a5d067b82b..bf45d6fddda5 100644 --- a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py +++ b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py @@ -20,6 +20,7 @@ ) q_len = 1 +PARTITION_SIZE = 512 def prepare_data( @@ -57,7 +58,7 @@ def numpy_allclose(x, y, rtol, atol): @pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32]) @pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32]) -@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32]) +@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32, 256, 512]) @pytest.mark.parametrize("HEAD_SIZE", [64, 128]) @pytest.mark.parametrize("NUM_ATTN_HEADS", [16]) @pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16]) @@ -76,81 +77,86 @@ def test_flash_decoding_attention( MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ device = get_current_device() - if use_alibi_slopes: - alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device) - else: - alibi_slopes = None - - q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( - BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device - ) - - k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( - k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device - ) + try: + if use_alibi_slopes: + alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device) + else: + alibi_slopes = None - block_tables = block_tables.to(device=device) - max_seq_len_across_batch = kv_seq_lengths.max().item() - kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE - output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) - sm_scale = 1.0 / (HEAD_SIZE**0.5) + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( + BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device + ) - k_torch 
= convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) - v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) - torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device) + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) - if use_alibi_slopes: - alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device) - torch_padding_mask = torch_padding_mask + alibi_mask + block_tables = block_tables.to(device=device) + max_seq_len_across_batch = kv_seq_lengths.max().item() + kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE + output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) + sm_scale = 1.0 / (HEAD_SIZE**0.5) - if len(torch_padding_mask.size()) == 4: - torch_padding_mask = torch_padding_mask[:, :, -1:, :] - else: - torch_padding_mask = torch_padding_mask[:, -1:, :] + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device) - mid_output = torch.empty( - size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device - ) - mid_output_lse = torch.empty( - size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device - ) + if use_alibi_slopes: + alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device) + torch_padding_mask = torch_padding_mask + alibi_mask - if dtype == torch.float16: - rtol = 1e-3 - atol = 1e-3 + if len(torch_padding_mask.size()) == 4: + torch_padding_mask = torch_padding_mask[:, :, -1:, :] + else: + torch_padding_mask = torch_padding_mask[:, -1:, :] - high_precision_q = q.to(torch.float32) - high_precision_k_torch = k_torch.to(torch.float32) - high_precision_v_torch = v_torch.to(torch.float32) - out_ref = torch_attn_ref( - high_precision_q, - high_precision_k_torch, - high_precision_v_torch, - torch_padding_mask, - BATCH_SIZE, - q_len, - max_seq_len_across_batch, - NUM_ATTN_HEADS, - NUM_KV_HEADS, - HEAD_SIZE, - ).to(torch.float16) + mid_output = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device + ) + exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device) + max_logits = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device + ) - else: - rtol = 1e-5 - atol = 1e-7 + if dtype == torch.float16: + rtol = 1e-3 + atol = 1e-3 + + high_precision_q = q.to(torch.float32) + high_precision_k_torch = k_torch.to(torch.float32) + high_precision_v_torch = v_torch.to(torch.float32) + out_ref = torch_attn_ref( + high_precision_q, + high_precision_k_torch, + high_precision_v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ).to(torch.float16) - out_ref = torch_attn_ref( - q, - k_torch, - v_torch, - torch_padding_mask, - BATCH_SIZE, - q_len, - max_seq_len_across_batch, - NUM_ATTN_HEADS, - NUM_KV_HEADS, - HEAD_SIZE, - ) + else: + rtol = 1e-5 + atol = 
1e-7 + + out_ref = torch_attn_ref( + q, + k_torch, + v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ) + + except torch.cuda.OutOfMemoryError: + pytest.skip("Required GPU memory is larger than capacity.") inference_ops.flash_decoding_attention( output, @@ -162,7 +168,8 @@ def test_flash_decoding_attention( BLOCK_SIZE, max_seq_len_across_batch, mid_output, - mid_output_lse, + exp_sums, + max_logits, alibi_slopes, sm_scale, )
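
With these changes, callers allocate two float32 statistics buffers alongside mid_output and pass them where the old mid_output_lse argument used to go; the extension then chooses between the single-pass v1 kernel and the partitioned v2 + reduce path internally, based on max_context_len and the number of (token, head) pairs, as in the dispatch logic added above. A usage sketch with hypothetical sizes; output, q, the caches, block_tables, kv_seq_lengths, alibi_slopes, sm_scale and the inference_ops handle are prepared exactly as in the test:

```python
import torch

# Hypothetical sizes, for illustration; buffer shapes follow the test above.
batch_size, num_heads, head_size, block_size = 4, 16, 128, 16
max_context_len = 1024
kv_max_split_num = (max_context_len + block_size - 1) // block_size  # generous upper bound

mid_output = torch.empty(batch_size, num_heads, kv_max_split_num, head_size,
                         dtype=torch.float32, device="cuda")
exp_sums = torch.empty(batch_size, num_heads, kv_max_split_num,
                       dtype=torch.float32, device="cuda")
max_logits = torch.empty(batch_size, num_heads, kv_max_split_num,
                         dtype=torch.float32, device="cuda")

# output, q, k_cache, v_cache, kv_seq_lengths, block_tables, alibi_slopes and
# sm_scale are built as in the test; only the trailing buffer arguments changed.
inference_ops.flash_decoding_attention(
    output, q, k_cache, v_cache, kv_seq_lengths, block_tables,
    block_size, max_context_len,
    mid_output, exp_sums, max_logits,   # exp_sums/max_logits replace mid_output_lse
    alibi_slopes, sm_scale,
)
```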