diff --git a/.github/workflows/release_wheel.yml b/.github/workflows/release_wheel.yml index 321d268d..aa9b1265 100644 --- a/.github/workflows/release_wheel.yml +++ b/.github/workflows/release_wheel.yml @@ -18,7 +18,7 @@ on: # required: true env: - TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX" + TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX" jobs: build: diff --git a/docs/installation.rst b/docs/installation.rst index 95fbf84a..266ebbdb 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -19,7 +19,7 @@ Prerequisites - Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version. -- Supported GPU architectures: ``sm80``, ``sm86``, ``sm89``, ``sm90`` (``sm75`` / ``sm70`` support is working in progress). +- Supported GPU architectures: ``sm75``, ``sm80``, ``sm86``, ``sm89``, ``sm90``. Quick Start ^^^^^^^^^^^ diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 09ef0941..c1bf4cc7 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -18,13 +18,10 @@ #include #include #include - -#include -#ifdef FLASHINFER_ENABLE_FP8 #include -#endif #include +#include #include #include #include @@ -537,6 +534,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( j] + tx * vec_size; } + // load k tiles #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { @@ -597,11 +595,7 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo return 512U; } } else { -#ifdef FLASHINFER_ENABLE_BF16 return 128U; -#else - return 64U; -#endif } } @@ -639,8 +633,8 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; + auto compute_capacity = GetCudaComputeCapability(); static_assert(bdx <= 32U); DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { constexpr uint32_t bdy = GROUP_SIZE; @@ -649,69 +643,74 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, constexpr uint32_t bdz = num_threads / (bdx * bdy); tensor_info_t info(1, seq_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 8U) : 1U; - const uint32_t smem_size = - 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + - 2U * bdy * bdz * sizeof(float); - auto kernel = SingleDecodeWithKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - if (seq_len <= 256 || tmp == nullptr) { - // no need to use partition-kv kernel - dim3 nblks = dim3(1, num_kv_heads); - dim3 nthrs = dim3(bdx, bdy, bdz); - float* lse = nullptr; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&lse, - (void*)&info, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta, - (void*)&seq_len}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } else { - // use partition-kv kernel - int num_blocks_per_sm = 0; - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, - num_threads, smem_size)); - uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm); - uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; - uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); - uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); - dim3 nblks = dim3(num_chunks, num_kv_heads); - if (nblks.x == 0 || nblks.y == 0) { - std::ostringstream err_msg; - err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; - throw std::runtime_error(err_msg.str()); - } - dim3 nthrs = dim3(bdx, bdy, bdz); - float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM); - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&tmp, - (void*)&tmp_lse, - (void*)&info, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta, - (void*)&kv_chunk_size}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + const uint32_t smem_size = + 2U * NUM_STAGES_SMEM * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + + 2U * bdy * bdz * sizeof(float); + auto kernel = SingleDecodeWithKVCacheKernel; FLASHINFER_CUDA_CALL( - MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); - } + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (seq_len <= 256 || tmp == nullptr) { + // no need to use partition-kv kernel + dim3 nblks = dim3(1, num_kv_heads); + dim3 nthrs = dim3(bdx, bdy, bdz); + float* lse = nullptr; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&lse, + (void*)&info, + (void*)&window_left, + (void*)&logits_soft_cap, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta, + (void*)&seq_len}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // use partition-kv kernel + int num_blocks_per_sm = 0; + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, 
num_threads, smem_size)); + uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm); + uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; + uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); + uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); + dim3 nblks = dim3(num_chunks, num_kv_heads); + if (nblks.x == 0 || nblks.y == 0) { + std::ostringstream err_msg; + err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; + throw std::runtime_error(err_msg.str()); + } + dim3 nthrs = dim3(bdx, bdy, bdz); + float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM); + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&tmp, + (void*)&tmp_lse, + (void*)&info, + (void*)&window_left, + (void*)&logits_soft_cap, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta, + (void*)&kv_chunk_size}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL( + MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); + } + }); }); return cudaSuccess; } @@ -730,7 +729,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( const uint32_t num_kv_heads = paged_kv.num_heads; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - constexpr uint32_t num_stages_smem = 2U; + auto compute_capacity = GetCudaComputeCapability(); constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { @@ -738,58 +737,63 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 4U) : 1U; - const uint32_t smem_size = - 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); - auto kernel = - BatchDecodeWithPagedKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - if (tmp_v == nullptr) { - // do not use partition-kv kernel - bool partition_kv = false; - dim3 nblks(padded_batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - - void* args[] = {(void*)&q, - (void*)&q_offset, - (void*)&paged_kv, - (void*)&kv_partition_info, - (void*)&o, - (void*)&lse, - (void*)&block_valid_mask, - (void*)&partition_kv, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } else { - // use partition-kv kernel - bool partition_kv = true; - void* args[] = {(void*)&q, - (void*)&q_offset, - (void*)&paged_kv, - (void*)&kv_partition_info, - (void*)&tmp_v, - (void*)&tmp_s, - (void*)&block_valid_mask, - (void*)&partition_kv, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - dim3 nblks(padded_batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse, - kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream)); - } + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + const uint32_t smem_size = + 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), + 2 * bdy * bdz * sizeof(float)); + auto kernel = + BatchDecodeWithPagedKVCacheKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (tmp_v == nullptr) { + // do not use partition-kv kernel + bool partition_kv = false; + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + + void* args[] = {(void*)&q, + (void*)&q_offset, + (void*)&paged_kv, + (void*)&kv_partition_info, + (void*)&o, + (void*)&lse, + (void*)&block_valid_mask, + (void*)&partition_kv, + (void*)&window_left, + (void*)&logits_soft_cap, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // use partition-kv kernel + bool partition_kv = true; + void* args[] = {(void*)&q, + (void*)&q_offset, + (void*)&paged_kv, + (void*)&kv_partition_info, + (void*)&tmp_v, + (void*)&tmp_s, + (void*)&block_valid_mask, + (void*)&partition_kv, + (void*)&window_left, + (void*)&logits_soft_cap, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse, + kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream)); + } + }); }); return cudaSuccess; } diff --git a/include/flashinfer/attention/handler.cuh 
b/include/flashinfer/attention/handler.cuh index bcb9bc82..e29b99c4 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -145,51 +145,53 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - constexpr uint32_t num_stages_smem = 2U; - constexpr uint32_t bdx = HEAD_DIM / vec_size; - static_assert(bdx <= 32); - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = std::max(128U, bdx * bdy); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; - const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; - const uint32_t smem_size = - 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); - - auto kernel = - BatchDecodeWithPagedKVCacheKernel; - int num_blocks_per_sm = 0; - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, - num_threads, smem_size)); - max_grid_size = num_blocks_per_sm * num_sm; - if (batch_size * num_kv_heads >= max_grid_size) { - split_kv = false; - new_batch_size = batch_size; - } else { - // compute max_num_pages_per_batch and new_batch_size - std::vector num_pages(batch_size); - for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; - } - std::tie(max_num_pages_per_batch, new_batch_size) = - PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, num_pages, - std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { - // do not use partition-kv kernel for short sequence, when not using CUDAGraph + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + constexpr uint32_t bdx = HEAD_DIM / vec_size; + static_assert(bdx <= 32); + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = std::max(128U, bdx * bdy); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 4U) : 1U; + const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; + const uint32_t smem_size = + 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); + + auto kernel = + BatchDecodeWithPagedKVCacheKernel; + int num_blocks_per_sm = 0; + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + max_grid_size = num_blocks_per_sm * num_sm; + if (batch_size * num_kv_heads >= max_grid_size) { split_kv = false; + new_batch_size = batch_size; } else { - // when using CUDAGraph, we always use partition-kv kernel - split_kv = true; + // compute max_num_pages_per_batch and new_batch_size + std::vector num_pages(batch_size); + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; + } + std::tie(max_num_pages_per_batch, new_batch_size) = + PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( + max_grid_size, num_kv_heads, num_pages, std::max(128 / page_size, 1U)); + if (new_batch_size == batch_size && !enable_cuda_graph) { + // do not use partition-kv kernel for short sequence, when not using CUDAGraph + split_kv = false; + } else { + // when using CUDAGraph, we always use partition-kv kernel + split_kv = true; + } } - } - return cudaSuccess; + return cudaSuccess; + }) } /*! @@ -556,11 +558,18 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz if (avg_packed_qo_len > 64 && head_dim < 256) { warp_layout = WarpLayout::k4x1x2; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2) } else { - if (avg_packed_qo_len > 16) { - warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1) + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (avg_packed_qo_len > 16) { + warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1) + } else { + // avg_packed_qo_len <= 16 + warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1) + } } else { - // avg_packed_qo_len <= 16 - warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1) + // NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout + warp_layout = WarpLayout::k4x1x1; } } const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout); diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 006d9753..b7c18ef0 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -18,9 +18,7 @@ #include #include #include -#ifdef FLASHINFER_ENABLE_FP8 #include -#endif #include #include "../cp_async.cuh" @@ -1936,10 +1934,17 @@ cudaError_t SinglePrefillWithKVCacheDispatched( if (unpacked_qo_len > 64 && HEAD_DIM < 256) { warp_layout = WarpLayout::k4x1x2; } else { - if (unpacked_qo_len > 16) { - warp_layout = WarpLayout::k4x1x1; + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (unpacked_qo_len > 16) { + warp_layout = WarpLayout::k4x1x1; + } else { + warp_layout = WarpLayout::k1x4x1; + } } else { - warp_layout = WarpLayout::k1x4x1; + 
// NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout + warp_layout = WarpLayout::k4x1x1; } } diff --git a/include/flashinfer/mma.cuh b/include/flashinfer/mma.cuh index 82f457a5..3c54a3f1 100644 --- a/include/flashinfer/mma.cuh +++ b/include/flashinfer/mma.cuh @@ -18,9 +18,7 @@ #include #include -#ifdef FLASHINFER_ENABLE_FP8 #include -#endif #include #include @@ -206,7 +204,6 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { #endif } -#ifdef FLASHINFER_ENABLE_FP8 /*! * \brief Wrapper of two mma m16n8k32 instructions for row major and column major f8 matrix * multiplication, accumulated in f32. @@ -307,7 +304,6 @@ __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uin "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"); #endif } -#endif /*! * \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix @@ -404,79 +400,82 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u } } #elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED) - if constexpr (mma_mode == MMAMode::kInit) { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + if constexpr (std::is_same::value) { + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } else { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, 
%10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } } else { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); } #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); #endif } -#ifdef FLASHINFER_ENABLE_FP8 /*! * \brief Use mma instructions to compute rowsum. */ @@ -515,7 +514,6 @@ __device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) { "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"); #endif } -#endif /*! * \brief Use mma instructions to compute rowsum. 
@@ -551,27 +549,30 @@ __device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) { "r"(1065369472), "f"(d[0]), "f"(d[1])); } #elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED) - static_assert(std::is_same::value, "bf16 mma instruction is not supported on sm_75"); - asm volatile( - "{\n" - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, _, %1, _}," - "{%2, %3}," - "{%4}," - "{%5, 0., %6, 0.};\n" - "}\n" - : "=f"(d[0]), "=f"(d[1]) - : "r"(s_u32[0]), "r"(s_u32[1]), "r"(1006648320), "f"(d[0]), "f"(d[1])); - asm volatile( - "{\n" - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, _, %1, _}," - "{%2, %3}," - "{%4}," - "{%5, 0., %6, 0.};\n" - "}\n" - : "=f"(d[0]), "=f"(d[1]) - : "r"(s_u32[2]), "r"(s_u32[3]), "r"(1006648320), "f"(d[0]), "f"(d[1])); + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3}," + "{%4}," + "{%5, 0., %6, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(1006648320), "f"(d[0]), "f"(d[1])); + asm volatile( + "{\n" + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3}," + "{%4}," + "{%5, 0., %6, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[2]), "r"(s_u32[3]), "r"(1006648320), "f"(d[0]), "f"(d[1])); + } else { + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); + } #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); #endif diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 3cda7415..4df2a006 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -40,6 +40,15 @@ using namespace cub; __VA_ARGS__ \ } +#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t BLOCK_THREADS = 1024; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t BLOCK_THREADS = 512; \ + __VA_ARGS__ \ + } + constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; @@ -686,25 +695,28 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t max_top_k_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &success, - &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - })}); - return cudaSuccess; + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = + sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &uniform_samples, &output, &success, + &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopKSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + })}); + return cudaSuccess; + }); } template @@ -764,25 +776,28 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k IdType* output, bool* success, uint32_t batch_size, IdType top_k_val, T top_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &top_k_arr, &top_p_arr, &output, - &success, &top_k_val, &top_p_val, &d, &max_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKTopPSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - })}); - return cudaSuccess; + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = + sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &uniform_samples, &top_k_arr, &top_p_arr, &output, + &success, &top_k_val, &top_p_val, &d, &max_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = 
TopKTopPSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + })}); + return cudaSuccess; + }); } template @@ -1166,7 +1181,7 @@ template cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, uint32_t batch_size, float top_p_val, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; + constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1186,40 +1201,44 @@ template cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopKRenormProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; }); - return cudaSuccess; } template cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKMaskLogitsKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopKMaskLogitsKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; }); - return cudaSuccess; } template #include #include @@ -228,6 
+229,15 @@ } \ } +#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, ...) \ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t NUM_STAGES_SMEM = 2; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t NUM_STAGES_SMEM = 1; \ + __VA_ARGS__ \ + } + namespace flashinfer { template @@ -235,6 +245,15 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { return (x + y - 1) / y; } +inline std::pair GetCudaComputeCapability() { + int device_id = 0; + cudaGetDevice(&device_id); + int major = 0, minor = 0; + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id); + return std::make_pair(major, minor); +} + template inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { std::vector host_array(size); diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index a702c8ee..d6895041 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -15,19 +15,15 @@ */ #pragma once #include +#include #include +#include #include #include #include #include "generated/dispatch.inc" -#ifdef FLASHINFER_ENABLE_BF16 -#include -#endif -#ifdef FLASHINFER_ENABLE_FP8 -#include -#endif using namespace flashinfer; diff --git a/python/setup.py b/python/setup.py index 2fd605be..22d2878a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -32,17 +32,14 @@ root = pathlib.Path(__name__).parent -enable_bf16 = True -# NOTE(Zihao): we haven't utilized fp8 tensor cores yet, so there is no # cuda arch check for fp8 at the moment. -enable_fp8 = True for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) if arch < 75: raise RuntimeError("FlashInfer requires sm75+") - elif arch == 75: - # disable bf16 for sm75 - enable_bf16 = False + +enable_bf16 = os.environ.get("FLASHINFER_ENABLE_BF16", "1") == "1" +enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1" if enable_bf16: torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16")
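
Note on the dispatch pattern this patch introduces: kernel configuration is now chosen at runtime from the device's compute capability (GetCudaComputeCapability in utils.cuh), with macros that expand their body once per branch so the chosen value remains constexpr and can still parameterize kernel templates (1 vs 2 shared-memory pipeline stages for decode, 512 vs 1024 threads for the sampling kernels). The sketch below is a minimal, self-contained illustration of that pattern using the same macro definition as the patch; DummyDecodeKernel and its launch are hypothetical placeholders, not FlashInfer code.

// nvcc -o dispatch_sketch dispatch_sketch.cu   (illustrative only)
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <utility>

// Same query the patch adds to include/flashinfer/utils.cuh.
inline std::pair<int, int> GetCudaComputeCapability() {
  int device_id = 0;
  cudaGetDevice(&device_id);
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
  return std::make_pair(major, minor);
}

// Same shape as DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM: the branch is
// taken at runtime, but NUM_STAGES_SMEM is constexpr inside each branch.
#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, ...) \
  if (compute_capacity.first >= 8) {                                                        \
    constexpr uint32_t NUM_STAGES_SMEM = 2;                                                 \
    __VA_ARGS__                                                                             \
  } else {                                                                                  \
    constexpr uint32_t NUM_STAGES_SMEM = 1;                                                 \
    __VA_ARGS__                                                                             \
  }

// Hypothetical stand-in for the real decode kernels.
template <uint32_t NUM_STAGES_SMEM>
__global__ void DummyDecodeKernel() {}

int main() {
  auto compute_capacity = GetCudaComputeCapability();
  DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
    // sm80+ instantiates the 2-stage variant; sm75 gets the 1-stage one.
    DummyDecodeKernel<NUM_STAGES_SMEM><<<1, 32>>>();
    std::printf("sm%d%d -> %u smem stage(s)\n", compute_capacity.first,
                compute_capacity.second, (unsigned)NUM_STAGES_SMEM);
  });
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}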
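
For intuition on why the decode path drops to a single pipeline stage below sm80, the dynamic shared-memory formula from BatchDecodeWithPagedKVCacheDispatched can be evaluated by hand. The figures below assume HEAD_DIM = 128, an fp16 KV cache and GROUP_SIZE = 1 (so tile_size_per_bdx = 4); they are back-of-the-envelope numbers for illustration, not values taken from the patch.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  // Derived as in decode.cuh for the assumed configuration.
  constexpr uint32_t head_dim = 128, sizeof_kv = 2, group_size = 1;
  constexpr uint32_t vec_size = std::max(16U / sizeof_kv, head_dim / 32U);  // 8
  constexpr uint32_t bdx = head_dim / vec_size;                             // 16
  constexpr uint32_t bdy = group_size;                                      // 1
  constexpr uint32_t num_threads = std::max(128U, bdx * bdy);               // 128
  constexpr uint32_t bdz = num_threads / (bdx * bdy);                       // 8
  constexpr uint32_t tile_size_per_bdx = 4;  // GROUP_SIZE == 1, 2-byte KV type
  for (uint32_t num_stages_smem : {2U, 1U}) {
    uint32_t smem_size =
        2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * sizeof_kv +
        std::max<uint32_t>(tile_size_per_bdx * num_threads * sizeof(void*),
                           2 * bdy * bdz * sizeof(float));
    std::printf("%u stage(s): %u bytes of dynamic shared memory\n",
                (unsigned)num_stages_smem, (unsigned)smem_size);
  }
  // Prints 36864 bytes for 2 stages vs 20480 bytes for 1 stage: the
  // single-stage variant leaves headroom on sm75, whose per-SM shared-memory
  // budget is smaller than on sm80+.
  return 0;
}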