From 66ee06683eaea7efe724c46df528ae47aa75eca2 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 5 Mar 2024 04:00:58 -0800 Subject: [PATCH] feat: pytorch api of fp8 kv-cache (#156) requested in #150 #155 #125 --- include/flashinfer/attention/decode.cuh | 4 +- python/csrc/batch_decode.cu | 157 +++++++++++++++------- python/csrc/pytorch_extension_utils.h | 28 ++++ python/csrc/single_decode.cu | 40 ++++-- python/flashinfer/prefill.py | 42 ++++++ python/flashinfer/utils.py | 4 + python/setup.py | 5 + python/tests/test_batch_decode_kernels.py | 28 ++-- 8 files changed, 236 insertions(+), 72 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index ca128bee..fb463d97 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -1263,7 +1263,7 @@ cudaError_t BatchDecodeWithPagedKVCache( template -cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o, +cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, float sm_scale, float rope_scale, @@ -1304,7 +1304,7 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType } template -cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o, +cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 271fd0a6..d425af4e 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -47,25 +47,45 @@ std::vector batch_decode_with_padded_kv_cache( } cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - auto o = torch::empty_like(q, q.options()); + auto o = torch::empty_like( + q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); torch::Tensor lse = torch::empty({0}); if (return_lse) { lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32); } - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - c_type* tmp = nullptr; - cudaError_t status = BatchDecodeWithPaddedKVCache( - static_cast(q.data_ptr()), static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, batch_size, - padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout, - PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ", - status); - return true; - }); + bool success; + if (is_float8_tensor(q)) { + success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { + nv_half* tmp = nullptr; + cudaError_t status = BatchDecodeWithPaddedKVCache( + static_cast(q.data_ptr()), static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, batch_size, + padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout, + PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ", + status); + return true; + }); + } else { + success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + c_type* tmp = nullptr; + cudaError_t status = BatchDecodeWithPaddedKVCache( + static_cast(q.data_ptr()), static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, batch_size, + padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout, + PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ", + status); + return true; + }); + } TORCH_CHECK(success, "BatchDecodeWithPaddedKVCache kernel launch failed: supported data type"); if (return_lse) { @@ -93,19 +113,36 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_.SetCUDAStream(torch_current_stream); - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { - DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, { - cudaError_t status = - handler_.BeginForward( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode)); - TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }) - }); + bool success; + if (is_float8_tensor(empty_data)) { + success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] { + DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, { + cudaError_t status = + handler_.BeginForward( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode)); + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }) + }); + } else { + success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { + DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, { + cudaError_t status = + handler_.BeginForward( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode)); + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }) + }); + } TORCH_CHECK(success, "BatchDecodeWithPagedKVCache failed to dispatch with dtype ", empty_data.scalar_type()); @@ -151,31 +188,57 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - torch::Tensor o = torch::empty_like(q, q.options()); + torch::Tensor o = torch::empty_like( + q, 
q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); torch::Tensor lse; if (return_lse) { lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32); } - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &handler_, static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, - static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), num_qo_heads, - PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); + + bool success; + if (is_float8_tensor(q)) { + success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { + DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( + &handler_, static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, + static_cast(o.data_ptr()), + /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), num_qo_heads, + PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + }); + return true; }); - return true; - }); + } else { + success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( + &handler_, static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, + static_cast(o.data_ptr()), + /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), num_qo_heads, + PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + }); + return true; + }); + } TORCH_CHECK(success, "BatchDecodeWithPagedKVCache failed to dispatch with dtype ", q.scalar_type()); diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 967b86f1..13b388ec 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -15,9 +15,16 @@ */ #pragma once #include +#include #include #include "generated/dispatch.inc" +#ifdef FLASHINFER_ENABLE_BF16 +#include +#endif +#ifdef FLASHINFER_ENABLE_FP8 +#include +#endif #ifdef FLASHINFER_ENABLE_BF16 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ @@ -49,6 +56,22 @@ }() #endif +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + return false; \ + } \ + }() + #define _DISPATCH_SWITCH(cond, ...) \ [&]() -> bool { \ switch (cond) { \ @@ -99,3 +122,8 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { #define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) #define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) + +inline bool is_float8_tensor(const torch::Tensor& tensor) { + return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || + tensor.scalar_type() == at::ScalarType::Float8_e5m2; +} diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 1c899158..ad51800d 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -44,19 +44,35 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc kv_len = k.size(1); } cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - auto o = torch::empty_like(q, q.options()); + auto o = torch::empty_like( + q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - cudaError_t status = SingleDecodeWithKVCache( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim, - kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); + bool success; + if (is_float8_tensor(q)) { + success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { + cudaError_t status = SingleDecodeWithKVCache( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim, + kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + } else { + success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + cudaError_t status = SingleDecodeWithKVCache( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim, + kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + } TORCH_CHECK(success, "SingleDecodeWithKVCache kernel launch failed, error: unsupported dtype"); return o; diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 33e4ad62..86efe917 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -17,6 +17,7 @@ import math from typing import Optional import torch +import logging 
try: from . import _kernels @@ -36,6 +37,7 @@ expand_5d, check_pos_encoding_mode, check_kv_layout, + is_float8, ) @@ -248,6 +250,14 @@ def single_prefill_with_kv_cache_return_lse( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 + if is_float8(q): + logging.warning( + "Our current prefill kernel implementation needs f16 input, the f8 inputs " + " are casted to f16, which could result in performance degradation." + ) + q = q.to(torch.float16) + k = k.to(torch.float16) + v = v.to(torch.float16) return _kernels.single_prefill_with_kv_cache( q, k, @@ -485,6 +495,14 @@ def forward( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 + if is_float8(q): + logging.warning( + "Our current prefill kernel implementation needs f16 input, the f8 inputs " + " are casted to f16, which could result in performance degradation." + ) + q = q.to(torch.float16) + paged_kv_data = paged_kv_data.to(torch.float16) + paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) return self._wrapper.forward( q, @@ -557,6 +575,14 @@ def forward_return_lse( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 + if is_float8(q): + logging.warning( + "Our current prefill kernel implementation needs f16 input, the f8 inputs " + " are casted to f16, which could result in performance degradation." + ) + q = q.to(torch.float16) + paged_kv_data = paged_kv_data.to(torch.float16) + paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) return self._wrapper.forward( q, @@ -769,6 +795,14 @@ def forward( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 + if is_float8(q): + logging.warning( + "Our current prefill kernel implementation needs f16 input, the f8 inputs " + " are casted to f16, which could result in performance degradation." + ) + q = q.to(torch.float16) + k = k.to(torch.float16) + v = v.to(torch.float16) return self._wrapper.forward( q, self._qo_indptr, @@ -838,6 +872,14 @@ def forward_return_lse( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 + if is_float8(q): + logging.warning( + "Our current prefill kernel implementation needs f16 input, the f8 inputs " + " are casted to f16, which could result in performance degradation." + ) + q = q.to(torch.float16) + k = k.to(torch.float16) + v = v.to(torch.float16) return self._wrapper.forward( q, self._qo_indptr, diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 133e839b..664ac879 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -53,3 +53,7 @@ def check_pos_encoding_mode(pos_encoding_mode: str): def check_kv_layout(kv_layout: str): if not hasattr(TensorLayout, kv_layout): raise KeyError("Invalide kv_layout {}".format(kv_layout)) + + +def is_float8(x: torch.Tensor): + return x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] diff --git a/python/setup.py b/python/setup.py index dc3463eb..a8bce576 100644 --- a/python/setup.py +++ b/python/setup.py @@ -30,6 +30,9 @@ root = pathlib.Path(__name__).parent enable_bf16 = True +# NOTE(Zihao): we haven't utilized fp8 tensor cores yet, so there is no +# cuda arch check for fp8 at the moment. 
+enable_fp8 = True for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) if arch < 75: @@ -40,6 +43,8 @@ if enable_bf16: torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16") +if enable_fp8: + torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_FP8") def get_instantiation_cu() -> List[str]: diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index 213ceea3..739dd823 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -30,6 +30,9 @@ @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +@pytest.mark.parametrize( + "dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] +) def test_batch_decode_with_paged_kv_cache( batch_size, kv_len, @@ -40,16 +43,15 @@ def test_batch_decode_with_paged_kv_cache( head_dim, kv_layout, pos_encoding_mode, + dtype, ): - q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half() + q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( - torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0) if kv_layout == "HND" - else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) - .to(0) - .half() + else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim).to(0) ) kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq kv_indices = torch.arange(0, total_num_pages).to(0).int() @@ -68,9 +70,9 @@ def test_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - "float16", + dtype, ) - o = wrapper.forward(q, kv_data, pos_encoding_mode=pos_encoding_mode) + o = wrapper.forward(q, kv_data.to(dtype), pos_encoding_mode=pos_encoding_mode) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] @@ -90,7 +92,7 @@ def test_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ) + ).to(dtype) vi = torch.cat( [ kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] @@ -105,7 +107,7 @@ def test_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ) + ).to(dtype) o_ref_i = flashinfer.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode ) @@ -115,5 +117,9 @@ def test_batch_decode_with_paged_kv_cache( if __name__ == "__main__": - test_batch_decode_with_paged_kv_cache(12, 54, 37, 8, 8, 8, 128, "HND", "NONE") - test_batch_decode_with_paged_kv_cache(12, 54, 37, 1, 8, 8, 128, "HND", "NONE") + test_batch_decode_with_paged_kv_cache( + 12, 54, 37, 8, 8, 8, 128, "HND", "NONE", torch.float16 + ) + test_batch_decode_with_paged_kv_cache( + 12, 54, 37, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2 + )
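
Usage sketch: the fp8 decode path added above can be exercised end-to-end from
PyTorch, mirroring the float8_e5m2 case in test_batch_decode_kernels.py. The head
counts, head_dim, kv_len, device, NHD layout, and the randn-then-cast
initialization below are illustrative assumptions, not values taken from the patch.

    import torch
    import flashinfer

    num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 1024

    # torch.randn cannot sample fp8 directly, so sample in fp16 and cast,
    # following the same pattern the tests above use.
    q = torch.randn(
        num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0"
    ).to(torch.float8_e5m2)
    k = torch.randn(
        kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
    ).to(torch.float8_e5m2)
    v = torch.randn(
        kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
    ).to(torch.float8_e5m2)

    # fp8 q/k/v dispatch to the fp8 kernels; per the changes above the output
    # tensor is allocated as fp16.
    o = flashinfer.single_decode_with_kv_cache(q, k, v)
    assert o.dtype == torch.float16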