From b98cc28f91aadbb8b831611f3676da92f892211d Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Wed, 28 Aug 2024 10:01:22 -0700 Subject: [PATCH] [Core][Kernels] Use FlashInfer backend for FP8 KV Cache when available. (#7798) Co-authored-by: Simon Mo --- tests/kernels/test_flashinfer.py | 228 +++++++++++++++++++++++++- vllm/attention/backends/flashinfer.py | 29 +++- vllm/attention/selector.py | 4 + 3 files changed, 249 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index f109792ad251b..67f12cf1ee08e 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -73,11 +73,14 @@ def ref_paged_attn( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode -def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], - num_heads: Tuple[int, - int], head_size: int, - dtype: torch.dtype, block_size: int, - soft_cap: Optional[float]) -> None: +def test_flashinfer_decode_with_paged_kv( + kv_lens: List[int], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) num_seqs = len(kv_lens) @@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], scale = head_size**-0.5 query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_value_cache = torch.randn(NUM_BLOCKS, 2, block_size, @@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], wrapper = flashinfer.\ BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", use_tensor_cores=( - (num_query_heads//num_kv_heads) not in (1, 2, 4, 8)) + (num_query_heads//num_kv_heads) > 4) ) wrapper.begin_forward(kv_indptr, kv_indices, @@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], soft_cap=soft_cap) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) +@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)]) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +def test_flashinfer_prefill_with_paged_fp8_kv( + seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], + head_size: int, dtype: torch.dtype, block_size: int, + soft_cap: Optional[float]) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + kv_cache_dtype = torch.float8_e4m3fn + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + NUM_BLOCKS_FP8 = 2048 + key_value_cache = torch.randn(NUM_BLOCKS_FP8, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + + k_scale = key_cache.amax().item() / 448.0 + v_scale = value_cache.amax().item() / 448.0 + + kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], + dim=1).to(kv_cache_dtype) + + assert 
(kv_cache_fp8.shape == key_value_cache.shape) + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS_FP8, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + qo_indptr = [0] + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + qo_indptr.append(qo_indptr[-1] + query_lens[i]) + + qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32) + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, "NHD") + wrapper.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + ) + + output = wrapper.forward(query, + kv_cache_fp8, + logits_soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale) + + ref_output = ref_paged_attn(query=query, + key_cache=key_cache.squeeze(1), + value_cache=value_cache.squeeze(1), + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap) + del query + del block_tables + # verify prefill fp8 + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)]) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@torch.inference_mode +def test_flashinfer_decode_with_paged_fp8_kv( + kv_lens: List[int], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: + # test doesn't work for num_heads = (16,16) + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + use_tensor_cores = (num_query_heads // num_kv_heads) > 4 + kv_cache_dtype = torch.float8_e4m3fn + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + NUM_BLOCKS_FP8 = 2048 + key_value_cache = torch.randn(NUM_BLOCKS_FP8, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + + k_scale = key_cache.amax().item() / 448.0 + v_scale = value_cache.amax().item() / 448.0 + + key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype) + value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype) + assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1) + kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + 
block_tables = torch.randint(0,
+                                 NUM_BLOCKS_FP8,
+                                 (num_seqs, max_num_blocks_per_seq),
+                                 dtype=torch.int32)
+
+    kv_indptr = [0]
+    kv_indices = []
+    kv_last_page_lens = []
+    for i in range(num_seqs):
+        seq_len = kv_lens[i]
+        assert seq_len > 0
+        num_blocks = (seq_len + block_size - 1) // block_size
+        kv_indices.extend(block_tables[i, :num_blocks])
+        kv_indptr.append(kv_indptr[-1] + num_blocks)
+        kv_last_page_len = seq_len % block_size
+        if kv_last_page_len == 0:
+            kv_last_page_len = block_size
+        kv_last_page_lens.append(kv_last_page_len)
+
+    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
+    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
+    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    wrapper = flashinfer.\
+        BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
+                                           use_tensor_cores=use_tensor_cores)
+    wrapper.begin_forward(kv_indptr,
+                          kv_indices,
+                          kv_last_page_lens,
+                          num_query_heads,
+                          num_kv_heads,
+                          head_size,
+                          block_size,
+                          "NONE",
+                          data_type=dtype)
+    output = wrapper.forward(query,
+                             kv_cache_fp8,
+                             logits_soft_cap=soft_cap,
+                             k_scale=k_scale,
+                             v_scale=v_scale)
+    key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
+    value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
+
+    ref_output = ref_paged_attn(query=query,
+                                key_cache=key_cache,
+                                value_cache=value_cache,
+                                query_lens=[1] * num_seqs,
+                                kv_lens=kv_lens,
+                                block_tables=block_tables,
+                                scale=scale,
+                                soft_cap=soft_cap)
+    # Temporary fix: increase the tolerance. Seems to be a flashinfer issue.
+    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - ref_output))}"
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index a8d76b79ff204..ca42f77f51cd4 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -83,6 +83,15 @@ def copy_blocks(
     def get_supported_head_sizes() -> List[int]:
         return [64, 128, 256]
 
+    @staticmethod
+    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        elif kv_cache_dtype == "fp8_e5m2":
+            return torch.float8_e5m2
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+
 
 class FlashInferState(AttentionState):
 
@@ -177,9 +186,9 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
             self._graph_decode_workspace_buffer, _indptr_buffer,
             self._graph_indices_buffer, _last_page_len_buffer, "NHD",
             use_tensor_cores)
-        kv_cache_dtype = get_kv_cache_torch_dtype(
-            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+            self.runner.kv_cache_dtype)
         paged_kv_indptr_tensor_host = torch.arange(0,
                                                    batch_size + 1,
                                                    dtype=torch.int32)
@@ -340,7 +349,7 @@ def begin_forward(self):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
-                data_type=self.data_type)
+            )
 
     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -366,7 +375,8 @@ def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
     def decode_metadata(self) -> Optional["FlashInferMetadata"]:
         # Currently chunked prefill is not supported
         if self.num_prefills > 0:
-            assert self.num_decode_tokens == 0
+            assert self.num_decode_tokens == 0, (
+                "Chunked prefill is not supported with flashinfer yet.")
             return None
         return self
 
@@ -578,6 +588,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
 
         kv_cache_dtype = get_kv_cache_torch_dtype(
             self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+
         return FlashInferMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
@@ -661,7 +672,6 @@ def forward(
         if attn_metadata.num_decode_tokens > 0:
             assert attn_metadata.num_prefill_tokens == 0, (
                 "Chunked prefill is not supported with flashinfer yet.")
-
         if kv_cache is not None:
             # Use the same reshape and cache kernel as flash attention.
             ops.reshape_and_cache_flash(
@@ -674,6 +684,11 @@ def forward(
                 k_scale,
                 v_scale,
             )
+            # The FlashInfer API requires the cache to be viewed as fp8_e4m3
+            # or fp8_e5m2 in order to process it in fp8.
+            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.kv_cache_dtype)
+            kv_cache = kv_cache.view(torch_dtype)
 
         query = query.contiguous(
         )  # Flashinfer requires query to be contiguous
@@ -711,5 +726,7 @@ def forward(
                 query,
                 kv_cache,
                 sm_scale=self.scale,
-                logits_soft_cap=self.logits_soft_cap)
+                logits_soft_cap=self.logits_soft_cap,
+                k_scale=k_scale,
+                v_scale=v_scale)
         return output.view(num_tokens, hidden_size)
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 54558fc2d7e53..c0e592c8b12a0 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -226,6 +226,10 @@ def which_attn_to_use(
     elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
         logger.info(
             "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+        logger.warning(
+            "Please use the FlashInfer backend with FP8 KV Cache for "
+            "better performance by setting the environment variable "
+            "VLLM_ATTENTION_BACKEND=FLASHINFER")
         selected_backend = _Backend.XFORMERS
     elif block_size % 16 != 0:
         logger.info(
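
A minimal sketch of how the FP8 KV-cache path added by this patch can be exercised end to end, assuming a vLLM build that contains these changes. The VLLM_ATTENTION_BACKEND=FLASHINFER variable and the "fp8"/"fp8_e4m3"/"fp8_e5m2" kv_cache_dtype strings come from the diff above; the model name, prompt, and sampling settings are placeholders.

# Sketch: run vLLM with the FlashInfer backend and an FP8 KV cache.
# Assumes a build that includes this patch; model and prompt are placeholders.
import os

# Select the FlashInfer attention backend before the engine picks one itself,
# matching the selector.py warning added above.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# "fp8" maps to fp8_e4m3 (torch.float8_e4m3fn) via
# FlashInferBackend.get_fp8_dtype_for_flashinfer; "fp8_e5m2" is also accepted.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
          kv_cache_dtype="fp8")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)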