diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index e9189c63..dbab3c30 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -166,51 +166,15 @@ cudaError_t BatchDecodeWithPagedKVCache( * \note This wrapper function should be only called after we call BeginForward function in the * BatchDecodeHandler. */ -template -cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( - BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { - paged_kv_t new_paged_kv = paged_kv; - kv_partition_info_t kv_partition_info; - DTypeOut* tmp = handler->GetTempFloatBuffer(); - - if (handler->IsForwardStarted()) { - if (tmp != nullptr) { - // create auxiliary information for cooperative kernels - new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition(); - new_paged_kv.indptr = handler->GetNewIndPtr(); - new_paged_kv.last_page_len = handler->GetNewLastPageLen(); - kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition(); - kv_partition_info.chunk_indptr = handler->GetChunkIndPtr(); - kv_partition_info.batch_idx_map = handler->GetBatchIdxMap(); - kv_partition_info.chunk_start_pos = handler->GetChunkStartPos(); - kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition(); - } - } else { - std::ostringstream err_msg; - err_msg << "Please call BatchDecodeHandler's BeginForward() before calling " - "BatchDecodeWithPagedKVCacheWrapper()"; - throw std::runtime_error(err_msg.str()); - } - - return BatchDecodeWithPagedKVCacheDispatched( - q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, rope_theta, - stream); - return cudaSuccess; -} - -template cudaError_t BatchDecodeWithPagedKVCacheWrapper( BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, + paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { - const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); + float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); const uint32_t num_kv_heads = paged_kv.num_heads; if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; @@ -219,18 +183,42 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( throw std::invalid_argument(err_msg.str()); } - // DISPATCH_GQA_GROUP_SIZE( - // num_qo_heads / num_kv_heads, GROUP_SIZE, - // {DISPATCH_HEAD_DIM( - // paged_kv.head_dim, HEAD_DIM, - // {DISPATCH_POS_ENCODING_MODE( - // pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { - // return BatchDecodeWithPagedKVCacheWrapperDispatched< - // page_storage, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, DTypeIn, - // DTypeOut, IdType>(handler, q, q_offset, paged_kv, o, lse, sm_scale, - // rope_scale, - // rope_theta, stream); - // })})})}); + DISPATCH_GQA_GROUP_SIZE( + num_qo_heads / num_kv_heads, GROUP_SIZE, + {DISPATCH_HEAD_DIM( + paged_kv.head_dim, HEAD_DIM, + {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { + paged_kv_t new_paged_kv = paged_kv; + kv_partition_info_t kv_partition_info; + DTypeOut* tmp = handler->GetTempFloatBuffer(); + + if (handler->IsForwardStarted()) { + if (tmp != nullptr) { + // create auxiliary information for cooperative kernels + new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition(); + new_paged_kv.indptr = handler->GetNewIndPtr(); + new_paged_kv.last_page_len = handler->GetNewLastPageLen(); + kv_partition_info.batch_size_before_partition = + handler->GetBatchSizeBeforePartition(); + kv_partition_info.chunk_indptr = handler->GetChunkIndPtr(); + kv_partition_info.batch_idx_map = handler->GetBatchIdxMap(); + kv_partition_info.chunk_start_pos = handler->GetChunkStartPos(); + kv_partition_info.seq_lens_before_partition = + handler->GetSeqLengthsBeforePartition(); + } + } else { + std::ostringstream err_msg; + err_msg << "Please call BatchDecodeHandler's BeginForward() before calling " + "BatchDecodeWithPagedKVCacheWrapper()"; + throw std::runtime_error(err_msg.str()); + } + + return BatchDecodeWithPagedKVCacheDispatched( + q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, + rope_theta, stream); + })})}); return cudaSuccess; } diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index 1bb2819c..112cbea4 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize("batch_size", [12, 17]) @pytest.mark.parametrize("kv_len", [54, 97]) @pytest.mark.parametrize("qo_len", [37, 17]) -@pytest.mark.parametrize("page_size", [1, 8, 16]) +@pytest.mark.parametrize("page_size", [1, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) diff --git a/python/tests/test_shared_prefix_kernels.py b/python/tests/test_shared_prefix_kernels.py index 8f994c77..b0d149fe 100644 --- a/python/tests/test_shared_prefix_kernels.py +++ b/python/tests/test_shared_prefix_kernels.py @@ -58,7 +58,7 @@ def test_batch_decode_with_shared_prefix_padded_kv_cache( @pytest.mark.parametrize("shared_kv_len", [54, 97, 1979]) @pytest.mark.parametrize("num_heads", [8, 16]) @pytest.mark.parametrize("head_dim", [128, 256]) -@pytest.mark.parametrize("page_size", [1, 4, 16]) +@pytest.mark.parametrize("page_size", [1, 16]) def test_batch_decode_with_shared_prefix_paged_kv_cache( batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim, page_size ): @@ -131,7 +131,7 @@ def test_batch_decode_with_shared_prefix_paged_kv_cache( @pytest.mark.parametrize("num_heads", [8, 16]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("head_dim", [128, 256]) -@pytest.mark.parametrize("page_size", [1, 4, 16]) +@pytest.mark.parametrize("page_size", [1, 16]) def test_batch_prefill_with_shared_prefix_paged_kv_cache( batch_size, unique_kv_len, shared_kv_len, num_heads, causal, head_dim, page_size ):