Support RoPE position info in batch prefill/decode kernels (#69)
This PR adds q/k position information to the batch prefill/decode kernels.
More specifically, the kernels now accept two additional arrays:
* `q_rope_position`, with shape `(total_q_len,)`, giving the in-sequence
position of each token in the input q.
* `k_rope_pos_offset`, with shape `(num_sequence,)`, giving the start
position of each sequence in k.

These two arrays enable on-the-fly RoPE calculation in multi-level cases
(see the host-side sketch after this message).

Tests `test_batch_prefill` and `test_batch_decode` pass. Performance is not
validated yet. Per discussion with Zihao, this change is unlikely to incur a
significant perf regression.
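
As an illustration of the intended semantics (not code from this PR), the two
arrays could be assembled on the host as follows. The helper name and the
two-level scenario (a shared prefix handled at another level) are assumptions:

```cpp
#include <cstdint>
#include <vector>

// Illustrative helper (not part of this PR): build both arrays for a batch
// where sequence i resumes at absolute position `start_pos[i]` (e.g. the
// length of a shared prefix handled at another level) and appends `q_len[i]`
// new query tokens.
void BuildRopePositionArrays(const std::vector<int32_t>& start_pos,
                             const std::vector<int32_t>& q_len,
                             std::vector<int32_t>& q_rope_position,    // (total_q_len,)
                             std::vector<int32_t>& k_rope_pos_offset)  // (num_sequence,)
{
  for (size_t i = 0; i < q_len.size(); ++i) {
    // The k entries of sequence i start at its resume position.
    k_rope_pos_offset.push_back(start_pos[i]);
    // Query token j of sequence i continues from the same position, so its
    // RoPE angle matches the position it will eventually occupy in k.
    for (int32_t j = 0; j < q_len[i]; ++j) {
      q_rope_position.push_back(start_pos[i] + j);
    }
  }
}
```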
MasterJH5574 authored Feb 1, 2024
1 parent c55cd60 commit a389ed4
Showing 10 changed files with 282 additions and 127 deletions.
49 changes: 23 additions & 26 deletions include/flashinfer/decode.cuh
@@ -497,7 +497,8 @@ template <bool partition_kv, RotaryMode rotary_mode, uint32_t num_stages_smem,
PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
- DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
+ DTypeIn* __restrict__ q, IdType* __restrict__ q_rope_position,
+ paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
@@ -520,6 +521,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
: 0;
const uint32_t seq_len =
partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len;
+ const uint32_t mapped_batch_idx =
+     partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx;

extern __shared__ uint8_t smem[];
DTypeIn* k_smem = (DTypeIn*)smem;
@@ -541,23 +544,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}
// apply rotary embedding to q matrix
- if constexpr (partition_kv) {
-   q_vec = vec_apply_llama_rope<vec_size, bdx>(
-       q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim,
-       freq, seq_len - 1);
- } else {
-   q_vec = vec_apply_llama_rope<vec_size, bdx>(
-       q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, seq_len - 1);
- }
+ q_vec = vec_apply_llama_rope<vec_size, bdx>(
+     q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq,
+     q_rope_position == nullptr ? (seq_len - 1) : q_rope_position[mapped_batch_idx]);
} else {
// do not apply rotary embedding to q matrix
- if constexpr (partition_kv) {
-   q_vec.cast_load(
-       q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim +
-       tx * vec_size);
- } else {
-   q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
- }
+ q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
}
block.sync();

@@ -627,7 +619,9 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
block.sync();
compute_qk<rotary_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
- freq, cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
+ freq,
+ (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
+     cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, sm_scale, s, st);
block.sync();

@@ -1120,7 +1114,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
RotaryMode ROTARY_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
- DTypeIn* q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
+ DTypeIn* q, IdType* q_rope_position,
+ paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse,
float rope_scale, float rope_theta, cudaStream_t stream) {
const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM));
@@ -1153,6 +1148,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
+ (void*)&q_rope_position,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
@@ -1171,6 +1167,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
+ (void*)&q_rope_position,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
@@ -1212,7 +1209,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchDecodeWithPagedKVCache(
- DTypeIn* q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
+ DTypeIn* q, IdType* q_rope_position,
+ paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse,
uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f,
float rope_theta = 1e4, cudaStream_t stream = nullptr) {
@@ -1228,13 +1226,12 @@ cudaError_t BatchDecodeWithPagedKVCache(

DISPATCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
- {DISPATCH_HEAD_DIM(
-     head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
-       return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
-                                                    kv_layout, ROTARY_MODE, DTypeIn, DTypeOut,
-                                                    IdType>(
-           q, paged_kv, kv_partition_info, o, tmp, lse, rope_scale, rope_theta, stream);
-     })})});
+ {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
+   return BatchDecodeWithPagedKVCacheDispatched<
+       GROUP_SIZE, HEAD_DIM, page_storage, kv_layout, ROTARY_MODE, DTypeIn,
+       DTypeOut, IdType>(q, q_rope_position, paged_kv, kv_partition_info, o,
+                         tmp, lse, rope_scale, rope_theta, stream);
+ })})});

return cudaSuccess;
}
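
For context on how the new argument is threaded through, here is a minimal
sketch of a call site. The wrapper, the fp16/int32 instantiation, and
`RotaryMode::kLlama` are assumptions for illustration; passing `nullptr` for
`q_rope_position` keeps the previous behavior (query position `seq_len - 1`):

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "flashinfer/decode.cuh"

using namespace flashinfer;

// Hypothetical wrapper (not from this PR): all device buffers are prepared by
// the caller; q_rope_position holds one position per request in decode.
template <PageStorage page_storage, QKVLayout kv_layout>
cudaError_t RunBatchDecode(half* q, int32_t* q_rope_position,
                           paged_kv_t<page_storage, kv_layout, half, int32_t> paged_kv,
                           kv_partition_info_t<int32_t> kv_partition_info, half* o,
                           half* tmp, float* lse, uint32_t num_qo_heads,
                           cudaStream_t stream) {
  return BatchDecodeWithPagedKVCache<page_storage, kv_layout, half, half, int32_t>(
      q, q_rope_position, paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads,
      RotaryMode::kLlama, /*rope_scale=*/1.f, /*rope_theta=*/1e4, stream);
}
```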
20 changes: 15 additions & 5 deletions include/flashinfer/page.cuh
@@ -88,6 +88,8 @@ struct paged_kv_t {
IdType* indptr;
// [batch_size] The offset of the last page for each request in the batch
IdType* last_page_len;
+ // [batch_size] The start position of each request in the batch.
+ IdType* rope_pos_offset;

/*!
* \brief Construct an empty paged key-value cache
@@ -101,7 +103,8 @@ struct paged_kv_t {
indices(nullptr),
ptrs(nullptr),
indptr(nullptr),
- last_page_len(nullptr) {}
+ last_page_len(nullptr),
+ rope_pos_offset(nullptr) {}

/*!
* \brief Construct a paged key-value cache
@@ -113,20 +116,23 @@ struct paged_kv_t {
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
+ * \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
uint32_t head_dim, uint32_t batch_size,
DType* data, IdType* indices, IdType* indptr,
- IdType* last_page_len)
+ IdType* last_page_len,
+ IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
data(data),
indices(indices),
indptr(indptr),
- last_page_len(last_page_len) {}
+ last_page_len(last_page_len),
+ rope_pos_offset(rope_pos_offset) {}

/*!
* \brief Construct a paged key-value cache
@@ -137,18 +143,22 @@ struct paged_kv_t {
* \param ptrs The array of pointers to each active page
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
+ * \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
uint32_t head_dim, uint32_t batch_size,
DType** ptrs, IdType* indptr,
- IdType* last_page_len)
+ IdType* last_page_len,
+ IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
ptrs(ptrs),
- indptr(indptr) {}
+ indptr(indptr),
+ last_page_len(last_page_len),
+ rope_pos_offset(rope_pos_offset) {}

/*!
* \brief Compute the offset of k element in the allocated buffer.
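
Putting the pieces together, constructing the updated struct might look like
the sketch below (buffer allocation elided; the factory function and the
`kIndices`/`kNHD` instantiation are illustrative). Because `rope_pos_offset`
defaults to `nullptr`, existing call sites compile unchanged and fall back to
an offset of 0:

```cpp
#include <cuda_fp16.h>

#include "flashinfer/page.cuh"

using namespace flashinfer;

// Illustrative sketch (not from this PR): wrap pre-allocated device buffers
// in a paged_kv_t, now with per-request RoPE start positions attached.
paged_kv_t<PageStorage::kIndices, QKVLayout::kNHD, half, int32_t> MakePagedKV(
    uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size,
    half* data, int32_t* indices, int32_t* indptr, int32_t* last_page_len,
    int32_t* rope_pos_offset) {
  return paged_kv_t<PageStorage::kIndices, QKVLayout::kNHD, half, int32_t>(
      num_heads, page_size, head_dim, batch_size, data, indices, indptr,
      last_page_len, rope_pos_offset);
}
```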