feat: enable head_dim=256 for attention kernels (#132)
As mentioned in #130, the kernels for `head_dim=256` are not compiled by default. This PR
exposes these attention kernels in the pip wheels and adds unit tests and benchmarks for
`head_dim=256`.
yzh119 authored Feb 25, 2024
1 parent a346b27 commit 0372acc
Showing 19 changed files with 415 additions and 297 deletions.
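Note: the user-visible change is an extra `head_dim` argument on the prefill planning calls (`begin_forward`), shown in the `python/flashinfer/prefill.py` and `python/flashinfer/cascade.py` diffs below. A minimal usage sketch for the paged-KV wrapper follows; the constructor arguments, tensor values, and shapes are illustrative assumptions, not taken from this diff.

import torch
import flashinfer  # assumes a wheel built from this commit, so the head_dim=256 kernels are available

num_qo_heads, num_kv_heads, head_dim = 8, 8, 256
workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# Illustrative index tensors for a batch of two requests (7 and 15 query tokens).
qo_indptr = torch.tensor([0, 7, 22], dtype=torch.int32, device="cuda:0")
paged_kv_indptr = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda:0")
paged_kv_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda:0")
paged_kv_last_page_len = torch.tensor([7, 15], dtype=torch.int32, device="cuda:0")

wrapper.begin_forward(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,  # new argument added by this PR
)
wrapper.end_forward()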
5 changes: 3 additions & 2 deletions include/flashinfer/handler.cuh
@@ -187,7 +187,8 @@ class BatchPrefillHandler {

template <typename IdType>
cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr,
- uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads) {
+ uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
+ uint32_t head_dim) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -197,7 +198,7 @@ class BatchPrefillHandler {
uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
std::vector<IdType> request_indices_h, tile_indices_h;
std::tie(num_frags_x_, num_qo_tiles_, request_indices_h, tile_indices_h) =
- split_qo_indptr(qo_indptr, batch_size, gqa_group_size, stream_);
+ split_qo_indptr(qo_indptr, batch_size, gqa_group_size, head_dim, stream_);
AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
request_indices_ =
allocator.aligned_alloc<void*>(sizeof(IdType) * request_indices_h.size(), 16);
14 changes: 6 additions & 8 deletions include/flashinfer/permuted_smem.cuh
@@ -63,27 +63,25 @@ struct smem_t {
template <uint32_t step_size>
static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset,
uint32_t step_idx) {
+ static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, "Unsupported step size");
if constexpr (step_size == 2) {
return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8;
} else if constexpr (step_size == 4) {
return (offset ^ 0x4) + (step_idx % 2 == 1) * 8;
- } else if constexpr (step_size % 8 == 0) {
- return offset + step_size;
} else {
- // Note(Zihao): not implemented yet.
- return 0;
+ // step_size % 8 == 0
+ return offset + step_size;
}
}

template <uint32_t step_size, uint32_t row_stride>
static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset) {
+ static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size");
if constexpr (step_size == 4) {
return (offset ^ 0x4) + step_size * row_stride;
- } else if constexpr (step_size % 8 == 0) {
- return offset + step_size * row_stride;
} else {
- // NOTE(Zihao): not implemented yet.
- return 0;
+ // step_size % 8 == 0
+ return offset + step_size * row_stride;
}
}

432 changes: 263 additions & 169 deletions include/flashinfer/prefill.cuh

Large diffs are not rendered by default.

92 changes: 41 additions & 51 deletions include/flashinfer/utils.cuh
@@ -81,40 +81,49 @@
__VA_ARGS__ \
}

- #define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
- if (num_frags_x == 1) { \
- constexpr size_t NUM_FRAGS_X = 1; \
- __VA_ARGS__ \
- } else if (num_frags_x == 2) { \
- constexpr size_t NUM_FRAGS_X = 2; \
- __VA_ARGS__ \
- } else { \
- std::cerr << "Unsupported num_frags_x: " << num_frags_x << std::endl; \
+ #define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
+ if (num_frags_x == 1) { \
+ constexpr size_t NUM_FRAGS_X = 1; \
+ __VA_ARGS__ \
+ } else if (num_frags_x == 2) { \
+ constexpr size_t NUM_FRAGS_X = 2; \
+ __VA_ARGS__ \
+ } else { \
+ std::ostringstream err_msg; \
+ err_msg << "Unsupported num_frags_x: " << num_frags_x; \
+ throw std::invalid_argument(err_msg.str()); \
}

- #define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
- if (max_frags_z == 4) { \
- constexpr size_t NUM_FRAGS_Z = 4; \
- __VA_ARGS__ \
- } else if (max_frags_z == 2) { \
- constexpr size_t NUM_FRAGS_Z = 2; \
- __VA_ARGS__ \
- } else { \
- std::cerr << "Unsupported max_frags_z: " << max_frags_z << std::endl; \
+ #define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
+ if (max_frags_z >= 4) { \
+ constexpr size_t NUM_FRAGS_Z = 4; \
+ __VA_ARGS__ \
+ } else if (max_frags_z >= 2) { \
+ constexpr size_t NUM_FRAGS_Z = 2; \
+ __VA_ARGS__ \
+ } else if (max_frags_z >= 1) { \
+ constexpr size_t NUM_FRAGS_Z = 1; \
+ __VA_ARGS__ \
+ } else { \
+ std::ostringstream err_msg; \
+ err_msg << "Unsupported max_frags_z: " << max_frags_z; \
+ throw std::invalid_argument(err_msg.str()); \
}

- #define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
- if (group_size == 1) { \
- constexpr size_t GROUP_SIZE = 1; \
- __VA_ARGS__ \
- } else if (group_size == 4) { \
- constexpr size_t GROUP_SIZE = 4; \
- __VA_ARGS__ \
- } else if (group_size == 8) { \
- constexpr size_t GROUP_SIZE = 8; \
- __VA_ARGS__ \
- } else { \
- std::cerr << "Unsupported group_size: " << group_size << std::endl; \
+ #define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
+ if (group_size == 1) { \
+ constexpr size_t GROUP_SIZE = 1; \
+ __VA_ARGS__ \
+ } else if (group_size == 4) { \
+ constexpr size_t GROUP_SIZE = 4; \
+ __VA_ARGS__ \
+ } else if (group_size == 8) { \
+ constexpr size_t GROUP_SIZE = 8; \
+ __VA_ARGS__ \
+ } else { \
+ std::ostringstream err_msg; \
+ err_msg << "Unsupported group_size: " << group_size; \
+ throw std::invalid_argument(err_msg.str()); \
}

#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \
@@ -169,25 +178,6 @@
} \
}

- #define DISPATCH_HEAD_DIM_PREFILL(head_dim, HEAD_DIM, ...) \
- switch (head_dim) { \
- case 64: { \
- constexpr size_t HEAD_DIM = 64; \
- __VA_ARGS__ \
- break; \
- } \
- case 128: { \
- constexpr size_t HEAD_DIM = 128; \
- __VA_ARGS__ \
- break; \
- } \
- default: { \
- std::ostringstream err_msg; \
- err_msg << "Unsupported head_dim: " << head_dim; \
- throw std::invalid_argument(err_msg.str()); \
- } \
- }

#define DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \
switch (rotary_mode) { \
case RotaryMode::kNone: { \
@@ -222,7 +212,7 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {

template <typename IdType>
std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_indptr(
- IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size,
+ IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, uint32_t head_dim,
cudaStream_t stream = nullptr) {
constexpr uint32_t num_warps = 4;
std::vector<IdType> qo_indptr_h(batch_size + 1), request_indices, tile_indices;
@@ -235,7 +225,7 @@ std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_in

const uint32_t total_q_len = qo_indptr_h[batch_size];
const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size;
- const uint32_t num_frags_x = avg_len_greater_than_64 ? 2 : 1;
+ const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1;
const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16;
uint32_t num_qo_tiles = 0;

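The `split_qo_indptr` change above is where `head_dim` feeds into scheduling: at `head_dim` 256 the planner keeps `num_frags_x = 1`, so each CTA covers 64 query rows instead of 128 (presumably to stay within shared-memory and register budgets at the wider head dimension). A small self-contained sketch of that arithmetic; the function and variable names are illustrative, not from the library.

def rows_per_cta(qo_lens, gqa_group_size, head_dim, num_warps=4):
    # Mirrors the num_frags_x heuristic in split_qo_indptr above (illustrative only).
    batch_size = len(qo_lens)
    total_q_len = sum(qo_lens)
    avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size
    num_frags_x = 2 if (head_dim < 256 and avg_len_greater_than_64) else 1
    return num_frags_x * num_warps * 16  # each fragment covers 16 query rows per warp

print(rows_per_cta([512, 512], gqa_group_size=1, head_dim=128))  # 128
print(rows_per_cta([512, 512], gqa_group_size=1, head_dim=256))  # 64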
6 changes: 3 additions & 3 deletions include/flashinfer/wrapper.cuh
@@ -122,7 +122,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapper(
- BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+ BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4,
@@ -142,8 +142,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
return BatchPrefillWithPagedKVCacheWrapperDispatched<
page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
- handler, q, qo_indptr, paged_kv, o, lse, rope_scale, rope_theta,
- stream);
+ handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, rope_scale,
+ rope_theta, stream);
})})})})});
return cudaSuccess;
}
30 changes: 14 additions & 16 deletions python/csrc/batch_prefill.cu
@@ -19,11 +19,9 @@

using namespace flashinfer;

- void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
- torch::Tensor qo_indptr,
- unsigned int batch_size,
- unsigned int num_qo_heads,
- unsigned int num_kv_heads) {
+ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
+ torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
+ unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
// NOTE(Zihao): not necessary to be a CUDA tensor
CHECK_CONTIGUOUS(qo_indptr);
CHECK_CONTIGUOUS(workspace_buffer);
@@ -37,9 +35,10 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor work
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
handler_.SetCUDAStream(torch_current_stream);

- cudaError_t status = handler_.BeginForward(
- static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
- static_cast<int32_t*>(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads);
+ cudaError_t status =
+ handler_.BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
+ workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+ batch_size, num_qo_heads, num_kv_heads, head_dim);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
@@ -140,11 +139,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
}
}

- void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
- torch::Tensor qo_indptr,
- unsigned int batch_size,
- unsigned int num_qo_heads,
- unsigned int num_kv_heads) {
+ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
+ torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
+ unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
// NOTE(Zihao): not necessary to be a CUDA tensor
CHECK_CONTIGUOUS(qo_indptr);
CHECK_CONTIGUOUS(workspace_buffer);
@@ -158,9 +155,10 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor wor
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
handler_.SetCUDAStream(torch_current_stream);

- cudaError_t status = handler_.BeginForward(
- static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
- static_cast<int32_t*>(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads);
+ cudaError_t status =
+ handler_.BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
+ workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+ batch_size, num_qo_heads, num_kv_heads, head_dim);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
6 changes: 4 additions & 2 deletions python/csrc/flashinfer_ops.h
@@ -79,7 +79,8 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
}
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
- unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
+ unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+ unsigned int head_dim);
void EndForward();
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
@@ -101,7 +102,8 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
return BatchPrefillWithRaggedKVCachePyTorchWrapper(layout);
}
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
- unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
+ unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+ unsigned int head_dim);
void EndForward();
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
torch::Tensor v, torch::Tensor kv_indptr, bool causal,
7 changes: 6 additions & 1 deletion python/flashinfer/cascade.py
@@ -578,7 +578,8 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
... paged_kv_indices,
... paged_kv_last_page_len,
... num_qo_heads,
- ... num_kv_heads
+ ... num_kv_heads,
+ ... head_dim,
... )
>>> outputs = []
>>> for i in range(num_layers):
@@ -641,6 +642,7 @@ def begin_forward(
paged_kv_last_page_len: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
+ head_dim: int,
):
r"""Create auxiliary data structures for shared-prefix batch prefill/append
attention for multiple forward calls within the same prefill/append step.
@@ -660,6 +662,8 @@ def begin_forward(
The number of query/output heads.
num_kv_heads : int
The number of key/value heads.
+ head_dim : int
+ The dimension of the heads.
Notes
-----
@@ -679,6 +683,7 @@
paged_kv_last_page_len,
num_qo_heads,
num_kv_heads,
+ head_dim,
)

def end_forward(self):
26 changes: 22 additions & 4 deletions python/flashinfer/prefill.py
@@ -298,7 +298,8 @@ class BatchPrefillWithPagedKVCacheWrapper:
... paged_kv_indices,
... paged_kv_last_page_len,
... num_qo_heads,
- ... num_kv_heads
+ ... num_kv_heads,
+ ... head_dim
... )
>>> outputs = []
>>> for i in range(num_layers):
@@ -365,6 +366,7 @@ def begin_forward(
paged_kv_last_page_len: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
+ head_dim: int,
):
r"""Create auxiliary data structures for batch prefill/append attention for
multiple forward calls within the same prefill/append step.
@@ -384,6 +386,8 @@
The number of query/output heads.
num_kv_heads : int
The number of key/value heads.
+ head_dim : int
+ The dimension of the heads.
Notes
-----
@@ -401,7 +405,12 @@
self._paged_kv_indices = paged_kv_indices
self._paged_kv_last_page_len = paged_kv_last_page_len
self._wrapper.begin_forward(
- self._workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads
+ self._workspace_buffer,
+ qo_indptr,
+ batch_size,
+ num_qo_heads,
+ num_kv_heads,
+ head_dim,
)

def end_forward(self):
@@ -571,7 +580,8 @@ class BatchPrefillWithRaggedKVCacheWrapper:
... qo_indptr,
... kv_indptr,
... num_qo_heads,
- ... num_kv_heads
+ ... num_kv_heads,
+ ... head_dim
... )
>>> outputs = []
>>> for i in range(num_layers):
@@ -635,6 +645,7 @@ def begin_forward(
kv_indptr: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
+ head_dim: int,
):
r"""Create auxiliary data structures for batch prefill/append attention for
multiple forward calls within the same prefill/append step.
@@ -649,6 +660,8 @@ def begin_forward(
The number of query/output heads.
num_kv_heads : int
The number of key/value heads.
+ head_dim : int
+ The dimension of the heads.
Notes
-----
@@ -664,7 +677,12 @@
self._qo_indptr = qo_indptr
self._kv_indptr = kv_indptr
self._wrapper.begin_forward(
- self._workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads
+ self._workspace_buffer,
+ qo_indptr,
+ batch_size,
+ num_qo_heads,
+ num_kv_heads,
+ head_dim,
)

def end_forward(self):
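The ragged (non-paged) wrapper gains the same argument; a short sketch under the same assumptions as the paged example above (constructor arguments and tensor values are illustrative):

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 8, 8, 256
workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")

qo_indptr = torch.tensor([0, 7, 22], dtype=torch.int32, device="cuda:0")
kv_indptr = torch.tensor([0, 33, 64], dtype=torch.int32, device="cuda:0")

wrapper.begin_forward(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim,  # new argument added by this PR
)
wrapper.end_forward()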