From ef5bb500af59b1fdf835228fa6dd93f93359d849 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 18 Dec 2024 20:17:12 -0800 Subject: [PATCH] [V1] Simplify prefix caching logic by removing `num_evictable_computed_blocks` (#11310) --- vllm/v1/core/kv_cache_manager.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 61a3f5fd6d841..78efacccfa078 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -201,23 +201,15 @@ def allocate_slots( f"num_tokens must be greater than 0, got {num_tokens}") # Touch the computed blocks to make sure they won't be evicted. - num_evictable_computed_blocks = 0 if self.enable_caching: self._touch(computed_blocks) - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = len( - [blk for blk in computed_blocks if blk.ref_cnt == 0]) else: assert not computed_blocks, ( "Computed blocks should be empty when " "prefix caching is disabled") num_required_blocks = cdiv(num_tokens, self.block_size) - if (num_required_blocks > self.free_block_queue.num_free_blocks - - num_evictable_computed_blocks): + if (num_required_blocks > self.free_block_queue.num_free_blocks): # Cannot allocate new blocks. return None @@ -225,8 +217,7 @@ def allocate_slots( # preallocated blocks. num_new_blocks = min( num_required_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks - - num_evictable_computed_blocks, + self.free_block_queue.num_free_blocks, # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape # [..., max_num_blocks_per_req].