diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 2b93adc1..786d069d 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -47,10 +47,12 @@ constexpr uint32_t warp_size = 32; namespace { +template <typename DTypeQKAccum> constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps) { return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) || - (num_frags_y > 4 && num_frags_y % 8 != 0)); + (num_frags_y > 4 && num_frags_y % 8 != 0) || + (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256)); } /*! @@ -1571,8 +1573,9 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( // control num_frags_z for maximum warp occupancy DISPATCH_NUM_FRAGS_Z( min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, - num_frags_z, num_warps)) { + if constexpr (is_invalid_configuration<DTypeQKAccum>( + num_frags_x, num_frags_y, num_frags_z, + num_warps)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : " @@ -1683,7 +1686,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* // control num_frags_z for maximum warp occupancy DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, num_warps)) { + if constexpr (is_invalid_configuration<DTypeQKAccum>(num_frags_x, num_frags_y, num_frags_z, + num_warps)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x @@ -1862,7 +1866,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps) / 2; 
DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, num_warps)) { + if constexpr (is_invalid_configuration<DTypeQKAccum>(num_frags_x, num_frags_y, num_frags_z, + num_warps)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x @@ -2066,7 +2071,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps) / 2; DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, num_warps)) { + if constexpr (is_invalid_configuration<DTypeQKAccum>(num_frags_x, num_frags_y, num_frags_z, + num_warps)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x