feat: pytorch api of fp8 kv-cache (#156)
Requested in #150, #155, and #125.
yzh119 authored Mar 5, 2024
1 parent de129b9 commit 66ee066
Showing 8 changed files with 236 additions and 72 deletions.
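For orientation, here is a minimal usage sketch of the fp8 decode path this commit adds. It assumes the public wrapper flashinfer.single_decode_with_kv_cache passes fp8 query/key/value tensors straight through to the extension and that the default NHD layout is used; shapes and calls are illustrative, not a verbatim API reference.

import torch
import flashinfer  # Python package built from this commit

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 32, 128, 4096

# Query and KV-cache stored as fp8 (e4m3 shown; e5m2 dispatches the same way).
q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)

# The extension dispatches on q's dtype; with fp8 inputs the output tensor is allocated as fp16.
o = flashinfer.single_decode_with_kv_cache(q, k, v)
assert o.dtype == torch.float16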
4 changes: 2 additions & 2 deletions include/flashinfer/attention/decode.cuh
@@ -1263,7 +1263,7 @@ cudaError_t BatchDecodeWithPagedKVCache(

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o,
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
float sm_scale, float rope_scale,
@@ -1304,7 +1304,7 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType
}

template <typename DTypeIn, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o,
cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim,
157 changes: 110 additions & 47 deletions python/csrc/batch_decode.cu
@@ -47,25 +47,45 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache(
}

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto o = torch::empty_like(q, q.options());
auto o = torch::empty_like(
q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type()));
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32);
}

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
c_type* tmp = nullptr;
cudaError_t status = BatchDecodeWithPaddedKVCache<c_type, c_type>(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k_padded.data_ptr()),
static_cast<c_type*>(v_padded.data_ptr()), static_cast<c_type*>(o.data_ptr()),
/*tmp=*/tmp,
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ",
status);
return true;
});
bool success;
if (is_float8_tensor(q)) {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] {
nv_half* tmp = nullptr;
cudaError_t status = BatchDecodeWithPaddedKVCache<c_type, nv_half>(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k_padded.data_ptr()),
static_cast<c_type*>(v_padded.data_ptr()), static_cast<nv_half*>(o.data_ptr()),
/*tmp=*/tmp,
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ",
status);
return true;
});
} else {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
c_type* tmp = nullptr;
cudaError_t status = BatchDecodeWithPaddedKVCache<c_type, c_type>(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k_padded.data_ptr()),
static_cast<c_type*>(v_padded.data_ptr()), static_cast<c_type*>(o.data_ptr()),
/*tmp=*/tmp,
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ",
status);
return true;
});
}
TORCH_CHECK(success, "BatchDecodeWithPaddedKVCache kernel launch failed: unsupported data type");

if (return_lse) {
@@ -93,19 +113,36 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
handler_.SetCUDAStream(torch_current_stream);

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
cudaError_t status =
handler_.BeginForward<PageStorage::kIndices, KV_LAYOUT, c_type, c_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode));
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
})
});
bool success;
if (is_float8_tensor(empty_data)) {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
cudaError_t status =
handler_.BeginForward<PageStorage::kIndices, KV_LAYOUT, c_type, nv_half, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode));
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
})
});
} else {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
cudaError_t status =
handler_.BeginForward<PageStorage::kIndices, KV_LAYOUT, c_type, c_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode));
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
})
});
}

TORCH_CHECK(success, "BatchDecodeWithPagedKVCache failed to dispatch with dtype ",
empty_data.scalar_type());
@@ -151,31 +188,57 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor o = torch::empty_like(
q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type()));
torch::Tensor lse;
if (return_lse) {
lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32);
}
bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
paged_kv_t<PageStorage::kIndices, KV_LAYOUT, c_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size,
static_cast<c_type*>(paged_kv_data.data_ptr()),
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices, KV_LAYOUT,
c_type, c_type, int32_t>(
&handler_, static_cast<c_type*>(q.data_ptr()), /*q_offset=*/nullptr, paged_kv,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr), num_qo_heads,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));

bool success;
if (is_float8_tensor(q)) {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
paged_kv_t<PageStorage::kIndices, KV_LAYOUT, c_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size,
static_cast<c_type*>(paged_kv_data.data_ptr()),
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices, KV_LAYOUT,
c_type, nv_half, int32_t>(
&handler_, static_cast<c_type*>(q.data_ptr()), /*q_offset=*/nullptr, paged_kv,
static_cast<nv_half*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr), num_qo_heads,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
});
return true;
});
return true;
});
} else {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, {
paged_kv_t<PageStorage::kIndices, KV_LAYOUT, c_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size,
static_cast<c_type*>(paged_kv_data.data_ptr()),
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices, KV_LAYOUT,
c_type, c_type, int32_t>(
&handler_, static_cast<c_type*>(q.data_ptr()), /*q_offset=*/nullptr, paged_kv,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr), num_qo_heads,
PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
});
return true;
});
}

TORCH_CHECK(success, "BatchDecodeWithPagedKVCache failed to dispatch with dtype ",
q.scalar_type());
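A hedged sketch of driving the paged-KV decode wrapper above from Python with an fp8 cache. The wrapper class name matches flashinfer's Python API; the begin_forward argument names, in particular the data_type argument that backs the empty_data tensor used for dispatch in BeginForward, are assumptions inferred from the C++ signatures above rather than a verbatim reference.

import torch
import flashinfer

num_qo_heads = num_kv_heads = 32
head_dim, page_size, batch_size, max_num_pages = 128, 16, 4, 64

workspace = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Paged KV-cache stored as fp8; NHD page layout is (max_num_pages, 2, page_size, num_kv_heads, head_dim).
kv_data = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim,
                      dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * (max_num_pages // batch_size)
kv_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.begin_forward(kv_indptr, kv_indices, kv_last_page_len,
                      num_qo_heads, num_kv_heads, head_dim, page_size,
                      data_type=torch.float8_e4m3fn)  # hypothetical name for the dispatch dtype argument

q = torch.randn(batch_size, num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
o = wrapper.forward(q, kv_data)  # fp16 output, per the empty_like(..., kFloat16) allocation above
wrapper.end_forward()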
28 changes: 28 additions & 0 deletions python/csrc/pytorch_extension_utils.h
@@ -15,9 +15,16 @@
*/
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

#include "generated/dispatch.inc"
#ifdef FLASHINFER_ENABLE_BF16
#include <cuda_bf16.h>
#endif
#ifdef FLASHINFER_ENABLE_FP8
#include <cuda_fp8.h>
#endif

#ifdef FLASHINFER_ENABLE_BF16
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
@@ -49,6 +56,22 @@
}()
#endif

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
case at::ScalarType::Float8_e4m3fn: { \
using c_type = __nv_fp8_e4m3; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Float8_e5m2: { \
using c_type = __nv_fp8_e5m2; \
return __VA_ARGS__(); \
} \
default: \
return false; \
} \
}()

#define _DISPATCH_SWITCH(cond, ...) \
[&]() -> bool { \
switch (cond) { \
@@ -99,3 +122,8 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)

inline bool is_float8_tensor(const torch::Tensor& tensor) {
return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn ||
tensor.scalar_type() == at::ScalarType::Float8_e5m2;
}
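The Python wrappers below rely on a matching is_float8 helper imported from flashinfer's utils module (one of the changed files not shown in this view). Presumably it mirrors the C++ is_float8_tensor above; a rough sketch:

import torch

def is_float8(x: torch.Tensor) -> bool:
    # True for either fp8 flavour recognized by the C++ side.
    return x.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)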
40 changes: 28 additions & 12 deletions python/csrc/single_decode.cu
@@ -44,19 +44,35 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
kv_len = k.size(1);
}
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
auto o = torch::empty_like(q, q.options());
auto o = torch::empty_like(
q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type()));

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
cudaError_t status = SingleDecodeWithKVCache(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()), static_cast<c_type*>(o.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim,
kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
return true;
});
bool success;
if (is_float8_tensor(q)) {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] {
cudaError_t status = SingleDecodeWithKVCache(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()), static_cast<nv_half*>(o.data_ptr()),
static_cast<nv_half*>(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim,
kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
return true;
});
} else {
success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
cudaError_t status = SingleDecodeWithKVCache(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()), static_cast<c_type*>(o.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim,
kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
return true;
});
}

TORCH_CHECK(success, "SingleDecodeWithKVCache kernel launch failed, error: unsupported dtype");
return o;
42 changes: 42 additions & 0 deletions python/flashinfer/prefill.py
@@ -17,6 +17,7 @@
import math
from typing import Optional
import torch
import logging

try:
from . import _kernels
@@ -36,6 +37,7 @@
expand_5d,
check_pos_encoding_mode,
check_kv_layout,
is_float8,
)


@@ -248,6 +250,14 @@ def single_prefill_with_kv_cache_return_lse(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input, the f8 inputs "
" are casted to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
return _kernels.single_prefill_with_kv_cache(
q,
k,
@@ -485,6 +495,14 @@ def forward(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input, the f8 inputs "
" are casted to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
paged_kv_data = paged_kv_data.to(torch.float16)

paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
return self._wrapper.forward(
q,
@@ -557,6 +575,14 @@ def forward_return_lse(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input, the f8 inputs "
" are casted to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
paged_kv_data = paged_kv_data.to(torch.float16)

paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
return self._wrapper.forward(
q,
@@ -769,6 +795,14 @@ def forward(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input, the f8 inputs "
" are casted to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
return self._wrapper.forward(
q,
self._qo_indptr,
@@ -838,6 +872,14 @@ def forward_return_lse(
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input, the f8 inputs "
" are casted to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
k = k.to(torch.float16)
v = v.to(torch.float16)
return self._wrapper.forward(
q,
self._qo_indptr,
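Unlike decode, prefill has no fp8 kernels yet, so fp8 inputs are upcast to fp16 on the Python side with a warning. A minimal sketch of the caller-visible behavior, assuming flashinfer.single_prefill_with_kv_cache applies the same cast as the return_lse variant shown above:

import torch
import flashinfer

qo_len, kv_len, num_heads, head_dim = 128, 4096, 32, 128

q = torch.randn(qo_len, num_heads, head_dim, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
k = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
v = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)

# Logs the "needs f16 input" warning, casts q/k/v to fp16, and runs the fp16 prefill kernel.
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True)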