bugfix: fix the compilation issue of pip wheels (#115)
This PR fixes #113, which was caused by #69 changing the
`BatchPrefillWithPagedKVCacheWrapperDispatched` signature without
updating `flashinfer_decl.h` accordingly.

Also fixes some minor formatting issues from #111.
yzh119 authored Feb 16, 2024
1 parent 1306d11 commit d4146fb
Showing 16 changed files with 44 additions and 30 deletions.
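For readers unfamiliar with the failure mode: `flashinfer_decl.h` provides explicit template instantiations, and an explicit instantiation must spell out exactly the signature of the template it instantiates, so any parameter added to a dispatch function (as #69 did) has to be mirrored in every instantiation declaration. Below is a minimal sketch of that pattern with hypothetical names; it is not FlashInfer's actual API, just the same mismatch in miniature.

```cpp
#include <cstdint>
#include <cuda_runtime.h>

// Dispatch template after a refactor added two RoPE-position parameters,
// mirroring what #69 did to the real signature.
template <typename DTypeIn, typename IdType>
cudaError_t WrapperDispatched(DTypeIn* q, IdType* qo_indptr,
                              IdType* q_rope_position,    // added parameter
                              IdType* k_rope_pos_offset,  // added parameter
                              cudaStream_t stream) {
  // Kernel launch elided in this sketch.
  return cudaSuccess;
}

// A stale explicit instantiation that still spells the old parameter list
// fails to compile, because it no longer matches any declared template:
//
//   template cudaError_t WrapperDispatched<float, int32_t>(
//       float*, int32_t*, cudaStream_t);
//
// The fix is to keep the instantiation in lockstep with the template:
template cudaError_t WrapperDispatched<float, int32_t>(
    float*, int32_t*, int32_t*, int32_t*, cudaStream_t);
```

That realignment, applied to the ragged- and paged-KV prefill wrappers, is the substance of this commit.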
5 changes: 3 additions & 2 deletions include/flashinfer/wrapper.cuh
@@ -207,8 +207,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
                 return BatchPrefillWithRaggedKVCacheWrapperDispatched<
                     GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
                     ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-                    handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size,
-                    num_kv_heads, rope_scale, rope_theta, stream);
+                    handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr,
+                    /*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads,
+                    rope_scale, rope_theta, stream);
               })})})})})});
   return cudaSuccess;
 }
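Note: passing `/*q_rope_position=*/nullptr` and `/*k_rope_pos_offset=*/nullptr` here presumably preserves the wrapper's previous behavior, with the kernel falling back to positions derived from the indptr arrays; callers that need custom RoPE position arrays would pass device pointers instead.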
1 change: 1 addition & 0 deletions python/csrc/batch_prefill.cu
@@ -216,6 +216,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
             &handler_, static_cast<c_type*>(q.data_ptr()),
             static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<c_type*>(k.data_ptr()),
             static_cast<c_type*>(v.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
+            /*q_rope_position=*/nullptr, /*k_rope_pos_offset=*/nullptr,
             static_cast<c_type*>(o.data_ptr()),
             /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
             num_kv_heads, rope_scale, rope_theta,
18 changes: 9 additions & 9 deletions python/csrc/cascade.cu
@@ -44,11 +44,11 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v_a.scalar_type(), c_type, [&] {
-    cudaError_t status =
-        MergeState(static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
-                   static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
-                   static_cast<c_type*>(v_merged.data_ptr()),
-                   static_cast<float*>(s_merged.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
+    cudaError_t status = MergeState(
+        static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
+        static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
+        static_cast<c_type*>(v_merged.data_ptr()), static_cast<float*>(s_merged.data_ptr()),
+        seq_len, num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeState kernel launch failed: ", cudaGetErrorString(status));
     return true;
@@ -80,10 +80,10 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] {
-    cudaError_t status =
-        MergeStateInPlace(static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
-                          static_cast<c_type*>(v_other.data_ptr()),
-                          static_cast<float*>(s_other.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
+    cudaError_t status = MergeStateInPlace(
+        static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
+        static_cast<c_type*>(v_other.data_ptr()), static_cast<float*>(s_other.data_ptr()), seq_len,
+        num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status));
     return true;
25 changes: 13 additions & 12 deletions python/csrc/flashinfer_decl.h
@@ -24,18 +24,19 @@
   template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched<                              \
       PageStorage::kIndices, LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION,   \
       CAUSAL, T, T, int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr,              \
+                             int32_t* q_rope_position,                                             \
                              paged_kv_t<PageStorage::kIndices, LAYOUT, T, int32_t> paged_kv, T* o, \
                              float* lse, float rope_scale, float rope_theta, cudaStream_t stream); \
   }
 
-#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \
-                                       LAYOUT, ROTARY_MODE)                                      \
-  namespace flashinfer {                                                                         \
-  template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched<                           \
-      GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \
-      BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr,   \
-      T* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, float rope_scale,            \
-      float rope_theta, cudaStream_t stream);                                                    \
+#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION,   \
+                                       LAYOUT, ROTARY_MODE)                                        \
+  namespace flashinfer {                                                                           \
+  template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched<                             \
+      GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>(  \
+      BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr,     \
+      int32_t* q_rope_position, int32_t* k_rope_pos_offset, T* o, float* lse, uint32_t batch_size, \
+      uint32_t num_kv_heads, float rope_scale, float rope_theta, cudaStream_t stream);             \
   }
 
 #define INST_SinglePrefill(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT, \
@@ -56,15 +57,15 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMod
           typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-    IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
-    const uint32_t num_kv_heads, const float rope_scale, const float rope_theta,
-    cudaStream_t stream);
+    IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
+    const uint32_t batch_size, const uint32_t num_kv_heads, const float rope_scale,
+    const float rope_theta, cudaStream_t stream);
 
 template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
           RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
           typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     float rope_scale, float rope_theta, cudaStream_t stream);
 
7 changes: 4 additions & 3 deletions python/csrc/page.cu
@@ -73,9 +73,10 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
         num_heads, page_size, head_dim, batch_size, static_cast<c_type*>(kv_data.data_ptr()),
         static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
         static_cast<int32_t*>(kv_last_page_len.data_ptr()));
-    cudaError_t status = AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
-                                            static_cast<c_type*>(append_value.data_ptr()),
-                                            static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
+    cudaError_t status =
+        AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
+                           static_cast<c_type*>(append_value.data_ptr()),
+                           static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "AppendPagedKVCache failed with error: ", cudaGetErrorString(status));
     return true;
2 changes: 1 addition & 1 deletion python/csrc/pytorch_extension_utils.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 #pragma once
-#include <torch/extension.h>
 #include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
 
 #include "generated/dispatch.inc"
 
1 change: 1 addition & 0 deletions python/flashinfer/__init__.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 from .decode import (
     single_decode_with_kv_cache,
     batch_decode_with_padded_kv_cache,
1 change: 1 addition & 0 deletions python/flashinfer/cascade.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import Optional
 import torch
7 changes: 4 additions & 3 deletions python/flashinfer/decode.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import Optional, Union
 import torch
@@ -477,9 +478,9 @@ def begin_forward(
         # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
         empty_data = torch.empty(
             0,
-            dtype=getattr(torch, data_type)
-            if isinstance(data_type, str)
-            else data_type,
+            dtype=(
+                getattr(torch, data_type) if isinstance(data_type, str) else data_type
+            ),
         )
         self._wrapper.begin_forward(
             self._workspace_buffer,
1 change: 1 addition & 0 deletions python/flashinfer/page.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import torch
 
 try:
1 change: 1 addition & 0 deletions python/flashinfer/prefill.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import Optional
 import torch
1 change: 1 addition & 0 deletions python/flashinfer/utils.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import torch
 
 
1 change: 1 addition & 0 deletions python/setup.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import pathlib
 import os
 import re
1 change: 1 addition & 0 deletions python/tests/test_batch_decode_kernels.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import numpy
 import pytest
 import torch
1 change: 1 addition & 0 deletions python/tests/test_batch_prefill_kernels.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import numpy
 import pytest
 import torch
1 change: 1 addition & 0 deletions python/tests/test_shared_prefix_kernels.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import numpy
 import pytest
 import torch
