diff --git a/include/flashinfer/wrapper.cuh b/include/flashinfer/wrapper.cuh index 64cd0956..dd508127 100644 --- a/include/flashinfer/wrapper.cuh +++ b/include/flashinfer/wrapper.cuh @@ -207,8 +207,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( return BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size, - num_kv_heads, rope_scale, rope_theta, stream); + handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr, + /*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads, + rope_scale, rope_theta, stream); })})})})})}); return cudaSuccess; } diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 18dc6a91..bea6b9bc 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -216,6 +216,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( &handler_, static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(kv_indptr.data_ptr()), + /*q_rope_position=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, batch_size, num_kv_heads, rope_scale, rope_theta, diff --git a/python/csrc/cascade.cu b/python/csrc/cascade.cu index 76fcbf87..4e1e379e 100644 --- a/python/csrc/cascade.cu +++ b/python/csrc/cascade.cu @@ -44,11 +44,11 @@ std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, tor auto s_merged = torch::empty({seq_len, num_heads}, s_a.options()); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v_a.scalar_type(), c_type, [&] { - cudaError_t status = - MergeState(static_cast(v_a.data_ptr()), static_cast(s_a.data_ptr()), - static_cast(v_b.data_ptr()), static_cast(s_b.data_ptr()), - static_cast(v_merged.data_ptr()), - static_cast(s_merged.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream); + cudaError_t status = MergeState( + static_cast(v_a.data_ptr()), static_cast(s_a.data_ptr()), + static_cast(v_b.data_ptr()), static_cast(s_b.data_ptr()), + static_cast(v_merged.data_ptr()), static_cast(s_merged.data_ptr()), + seq_len, num_heads, head_dim, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status)); return true; @@ -80,10 +80,10 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] { - cudaError_t status = - MergeStateInPlace(static_cast(v.data_ptr()), static_cast(s.data_ptr()), - static_cast(v_other.data_ptr()), - static_cast(s_other.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream); + cudaError_t status = MergeStateInPlace( + static_cast(v.data_ptr()), static_cast(s.data_ptr()), + static_cast(v_other.data_ptr()), static_cast(s_other.data_ptr()), seq_len, + num_heads, head_dim, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc/flashinfer_decl.h b/python/csrc/flashinfer_decl.h index a9144616..5ae1c251 100644 --- a/python/csrc/flashinfer_decl.h +++ b/python/csrc/flashinfer_decl.h @@ -24,18 +24,19 @@ template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched< \ PageStorage::kIndices, LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, \ CAUSAL, T, T, int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, \ + int32_t* q_rope_position, \ paged_kv_t paged_kv, T* o, \ float* lse, float rope_scale, float rope_theta, cudaStream_t stream); \ } -#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ - LAYOUT, ROTARY_MODE) \ - namespace flashinfer { \ - template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched< \ - GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \ - BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr, \ - T* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, float rope_scale, \ - float rope_theta, cudaStream_t stream); \ +#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ + LAYOUT, ROTARY_MODE) \ + namespace flashinfer { \ + template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched< \ + GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \ + BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr, \ + int32_t* q_rope_position, int32_t* k_rope_pos_offset, T* o, float* lse, uint32_t batch_size, \ + uint32_t num_kv_heads, float rope_scale, float rope_theta, cudaStream_t stream); \ } #define INST_SinglePrefill(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT, \ @@ -56,15 +57,15 @@ template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size, - const uint32_t num_kv_heads, const float rope_scale, const float rope_theta, - cudaStream_t stream); + IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, + const uint32_t batch_size, const uint32_t num_kv_heads, const float rope_scale, + const float rope_theta, cudaStream_t stream); template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, paged_kv_t paged_kv, DTypeOut* o, float* lse, float rope_scale, float rope_theta, cudaStream_t stream); diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 391576cd..b71751cf 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -73,9 +73,10 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, num_heads, page_size, head_dim, batch_size, static_cast(kv_data.data_ptr()), static_cast(kv_indices.data_ptr()), static_cast(kv_indptr.data_ptr()), static_cast(kv_last_page_len.data_ptr())); - cudaError_t status = AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), - static_cast(append_value.data_ptr()), - static_cast(append_indptr.data_ptr()), torch_current_stream); + cudaError_t status = + AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), + static_cast(append_value.data_ptr()), + static_cast(append_indptr.data_ptr()), torch_current_stream); TORCH_CHECK(status == cudaSuccess, "AppendPagedKVCache failed with error: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 780fe89a..54c3aba8 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ #pragma once -#include #include +#include #include "generated/dispatch.inc" diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 53b7688e..e198f48c 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from .decode import ( single_decode_with_kv_cache, batch_decode_with_padded_kv_cache, diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 90e4d25e..e77955ba 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import math from typing import Optional import torch diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index dc67024b..0139c76c 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import math from typing import Optional, Union import torch @@ -477,9 +478,9 @@ def begin_forward( # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info empty_data = torch.empty( 0, - dtype=getattr(torch, data_type) - if isinstance(data_type, str) - else data_type, + dtype=( + getattr(torch, data_type) if isinstance(data_type, str) else data_type + ), ) self._wrapper.begin_forward( self._workspace_buffer, diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index d68db529..112ed30a 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import torch try: diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 0e617a45..8799be84 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import math from typing import Optional import torch diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 857399a7..14beb9be 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import torch diff --git a/python/setup.py b/python/setup.py index f594baa6..8f38e53c 100644 --- a/python/setup.py +++ b/python/setup.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import pathlib import os import re diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index 56ae1458..e396efd1 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import numpy import pytest import torch diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index 4fd0967f..02851bac 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import numpy import pytest import torch diff --git a/python/tests/test_shared_prefix_kernels.py b/python/tests/test_shared_prefix_kernels.py index 092da4f8..f975c075 100644 --- a/python/tests/test_shared_prefix_kernels.py +++ b/python/tests/test_shared_prefix_kernels.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import numpy import pytest import torch