perf: multiply q by sm_scale in decode kernels (#144)
The same optimization is already used in our prefill attention kernels; this PR applies it to the decode attention kernels.
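For context, a minimal standalone sketch of the trick, not the kernel's actual code: the real implementation works on vec_t<float, vec_size> registers inside compute_qk, whereas the raw float* helpers below (dot_scaled_per_element, scale_query_once, dot_prescaled) are illustrative names. Because (sm_scale * q) · k equals sm_scale * (q · k), scaling q once right after it is loaded produces the same attention score while removing one multiply per element from the inner loop that runs for every KV position.

#include <cstdint>

// Before this change: sm_scale is applied inside the hot reduction, costing one
// extra multiply per element for every KV position visited.
template <uint32_t vec_size>
__device__ __forceinline__ float dot_scaled_per_element(const float* q_vec, const float* k_vec,
                                                        float sm_scale) {
  float s = 0.f;
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    s += q_vec[i] * k_vec[i] * sm_scale;
  }
  return s;
}

// After this change: the query vector is scaled once, right after it is loaded...
template <uint32_t vec_size>
__device__ __forceinline__ void scale_query_once(float* q_vec, float sm_scale) {
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    q_vec[i] *= sm_scale;
  }
}

// ...so every subsequent q.k reduction is a plain dot product; sm_scale is already
// folded into q_vec and the result is unchanged up to floating-point rounding.
template <uint32_t vec_size>
__device__ __forceinline__ float dot_prescaled(const float* q_vec, const float* k_vec) {
  float s = 0.f;
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    s += q_vec[i] * k_vec[i];
  }
  return s;
}

In the diff below, the first pattern disappears from compute_qk and the second appears right after each q_vec.cast_load call in the three decode kernels.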
yzh119 authored Mar 1, 2024
1 parent 5f70697 commit 660c559
Showing 1 changed file with 20 additions and 7 deletions.
include/flashinfer/attention/decode.cuh (20 additions & 7 deletions)
@@ -68,8 +68,8 @@ template <RotaryMode rotary_mode, uint32_t vec_size, uint32_t bdx, uint32_t tile
 __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx,
                                            const vec_t<float, vec_size>& q_vec,
                                            const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
-                                           uint32_t iter_base, uint32_t iter_bound, float sm_scale,
-                                           float* s, state_t<vec_size>& st) {
+                                           uint32_t iter_base, uint32_t iter_bound, float* s,
+                                           state_t<vec_size>& st) {
   uint32_t tx = threadIdx.x, tz = threadIdx.z;
   float m_prev = st.m;
 #pragma unroll
@@ -86,7 +86,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
     s[j] = 0.f;
 #pragma unroll
     for (uint32_t i = 0; i < vec_size; ++i) {
-      s[j] += q_vec[i] * k_vec[i] * sm_scale;
+      s[j] += q_vec[i] * k_vec[i];
     }
 #pragma unroll
     for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
@@ -240,6 +240,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
     // do not apply rotary embedding to q matrix
     q_vec.cast_load(q + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
   }
+  // multiply q_vec by sm_scale
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    q_vec[i] *= sm_scale;
+  }
   block.sync();
 
   uint32_t chunk_start = kv_chunk_idx * kv_chunk_size;
@@ -286,8 +291,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
     block.sync();
     compute_qk<rotary_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
         k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
-        freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, sm_scale,
-        s, st_local);
+        freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, s,
+        st_local);
     block.sync();
     // load k
     for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
@@ -385,6 +390,10 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
     q_vec.cast_load(q + batch_idx * num_qo_heads * head_dim +
                     info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
   }
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    q_vec[i] *= sm_scale;
+  }
   block.sync();
 
   // preload k tiles and v tiles
@@ -421,7 +430,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
     block.sync();
     compute_qk<rotary_mode, vec_size, bdx, bdy>(k_smem + (stage_idx * bdz + tz) * bdy * head_dim,
                                                 stage_idx, q_vec, freq, consumer_kv_idx_base,
-                                                iter * bdy * bdz, seq_len, sm_scale, s, st_local);
+                                                iter * bdy * bdz, seq_len, s, st_local);
     block.sync();
     // load k
     cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
@@ -551,6 +560,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
     // do not apply rotary embedding to q matrix
     q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
   }
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    q_vec[i] *= sm_scale;
+  }
   block.sync();
 
   // preload k/v tiles
@@ -622,7 +635,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
         freq,
         (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
             cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
-        iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, sm_scale, s, st);
+        iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, s, st);
     block.sync();
 
 #pragma unroll
