Skip to content

Commit

Permalink
misc: ignore invalid configurations (#162)
Browse files Browse the repository at this point in the history
To speed up compilation a little bit.
  • Loading branch information
yzh119 authored Mar 7, 2024
1 parent 30fa584 commit 23c02ce
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ constexpr uint32_t warp_size = 32;

namespace {

template <typename DTypeQKAccum>
constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags_y,
uint32_t num_frags_z, uint32_t num_warps) {
return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) ||
(num_frags_y > 4 && num_frags_y % 8 != 0));
(num_frags_y > 4 && num_frags_y % 8 != 0) ||
(num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256));
}

/*!
Expand Down Expand Up @@ -1571,8 +1573,9 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation(
// control num_frags_z for maximum warp occupancy
DISPATCH_NUM_FRAGS_Z(
min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, {
if constexpr (is_invalid_configuration(num_frags_x, num_frags_y,
num_frags_z, num_warps)) {
if constexpr (is_invalid_configuration<DTypeQKAccum>(
num_frags_x, num_frags_y, num_frags_z,
num_warps)) {
// Invalid configuration, skip
std::ostringstream err_msg;
err_msg << "FlashInfer Internal Error: Invalid configuration : "
Expand Down Expand Up @@ -1683,7 +1686,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*

// control num_frags_z for maximum warp occupancy
DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, {
if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, num_warps)) {
if constexpr (is_invalid_configuration<DTypeQKAccum>(num_frags_x, num_frags_y, num_frags_z,
num_warps)) {
// Invalid configuration, skip
std::ostringstream err_msg;
err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x
Expand Down Expand Up @@ -1862,7 +1866,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
(max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps) / 2;

DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, {
if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, num_warps)) {
if constexpr (is_invalid_configuration<DTypeQKAccum>(num_frags_x, num_frags_y, num_frags_z,
num_warps)) {
// Invalid configuration, skip
std::ostringstream err_msg;
err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x
Expand Down Expand Up @@ -2066,7 +2071,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
(max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps) / 2;

DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, {
if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, num_warps)) {
if constexpr (is_invalid_configuration<DTypeQKAccum>(num_frags_x, num_frags_y, num_frags_z,
num_warps)) {
// Invalid configuration, skip
std::ostringstream err_msg;
err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x
Expand Down

0 comments on commit 23c02ce

Please sign in to comment.