From 7a43c65c0ee1346d92ac995fae4531ad24cba34f Mon Sep 17 00:00:00 2001 From: Jingxin Pan Date: Mon, 30 Dec 2024 13:03:24 +0800 Subject: [PATCH] [fix] fix CI warnings Signed-off-by: Jingxin Pan --- .../offline_inference_with_global_prefix.py | 6 +- vllm/core/block/cpu_gpu_block_allocator.py | 8 ++ vllm/core/block/interfaces.py | 10 +++ vllm/core/block/naive_block.py | 8 ++ vllm/core/scheduler.py | 13 ++-- vllm/engine/llm_engine.py | 8 +- vllm/global_cache.py | 65 ++++++++++------ vllm/worker/model_runner.py | 74 +++++++++++++------ vllm/worker/worker_base.py | 9 ++- 9 files changed, 143 insertions(+), 58 deletions(-) mode change 100644 => 100755 vllm/core/block/cpu_gpu_block_allocator.py mode change 100644 => 100755 vllm/core/block/interfaces.py mode change 100644 => 100755 vllm/core/block/naive_block.py mode change 100644 => 100755 vllm/core/scheduler.py mode change 100644 => 100755 vllm/worker/worker_base.py diff --git a/examples/offline_inference_with_global_prefix.py b/examples/offline_inference_with_global_prefix.py index ddfc3015e3972..946969c8e8147 100755 --- a/examples/offline_inference_with_global_prefix.py +++ b/examples/offline_inference_with_global_prefix.py @@ -85,8 +85,10 @@ """ We can simulate the global prefix cache this way: -1. Run the vllm instance with APC for some time, so some prompts may not hit APC as they are old and evicted -2. Delete the first vllm instance and start a new one. In this case, global kv cache can be hit directly +1. Run the vllm instance with APC for some time, so some prompts may not +hit APC as they are old and evicted +2. Delete the first vllm instance and start a new one. In this case, +global kv cache can be hit directly Here we demo the second option. """ # Destroy the LLM object and free up the GPU memory. 
diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py old mode 100644 new mode 100755 index 3a57487a6cd8a..96ce2038a213d --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -416,6 +416,14 @@ def computed(self): def computed(self, value): self._proxy.computed = value + @property + def global_computed(self): + return self._proxy.global_computed + + @global_computed.setter + def global_computed(self, value): + self._proxy.global_computed = value + @property def last_accessed(self) -> float: return self._proxy.last_accessed diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py old mode 100644 new mode 100755 index 985a1098b6cd1..d306f030885d3 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -66,6 +66,16 @@ def computed(self, value) -> bool: """Should be only used by PrefixCacingAllocator""" raise NotImplementedError + @property + @abstractmethod + def global_computed(self) -> bool: + raise NotImplementedError + + @global_computed.setter + @abstractmethod + def global_computed(self, value) -> bool: + raise NotImplementedError + @property @abstractmethod def last_accessed(self) -> float: diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py old mode 100644 new mode 100755 index 9b94918ab38ef..a7dde103bb214 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -404,6 +404,14 @@ def computed(self) -> bool: def computed(self, value) -> None: raise NotImplementedError + @property + def global_computed(self) -> bool: + raise NotImplementedError + + @global_computed.setter + def global_computed(self, value) -> None: + raise NotImplementedError + @property def last_accessed(self) -> float: raise NotImplementedError diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py old mode 100644 new mode 100755 index 1011d15f330ff..3bc00b12c22cb --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1336,19 +1336,22 @@ def schedule( seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) - if self.cache_config.enable_prefix_caching and self.cache_config.num_global_cache_blocks > 0 and seq_group.is_prefill(): + if (self.cache_config.enable_prefix_caching and + self.cache_config.num_global_cache_blocks > 0 and + seq_group.is_prefill()): global_computed_list = [] - block_hash_dict = {} + block_h_dict = {} for block_id in block_tables[seq_id]: - for block in self.block_manager.block_tables[seq_id].blocks: + for block in \ + self.block_manager.block_tables[seq_id].blocks: if block.block_id == block_id: if block.global_computed: global_computed_list.append(block_id) if block.content_hash is not None: - block_hash_dict[block_id] = block.content_hash + block_h_dict[block_id] = block.content_hash break block_global_computed_tables[seq_id] = global_computed_list - block_hash_map[seq_id] = block_hash_dict + block_hash_map[seq_id] = block_h_dict self.block_manager.access_all_blocks_in_seq(seq, now) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dc1cac9dde8f4..e309ab4871f82 100755 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -371,8 +371,12 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: if self.model_config.use_async_output_proc else None) for v_id in range(self.parallel_config.pipeline_parallel_size) ] - if self.cache_config.enable_prefix_caching and self.cache_config.num_global_cache_blocks > 0: - 
global_cache_instance.setGlabalCacheBlockNum(self.model_config, self.cache_config, self.model_executor.driver_worker.cache_engine[0].dtype) + if (self.cache_config.enable_prefix_caching and + self.cache_config.num_global_cache_blocks > 0): + global_cache_instance.setGlabalCacheBlockNum( + self.model_config, + self.cache_config, + self.model_executor.driver_worker.cache_engine[0].dtype) # Metric Logging. if self.log_stats: diff --git a/vllm/global_cache.py b/vllm/global_cache.py index cd6c12b4ebfbb..0e58905717479 100755 --- a/vllm/global_cache.py +++ b/vllm/global_cache.py @@ -1,5 +1,5 @@ from collections import deque -from typing import Deque, Dict +from typing import Deque, Dict, Optional import torch from vllm.config import ModelConfig, CacheConfig from vllm.logger import init_logger @@ -9,8 +9,10 @@ class GlobalCache: """ - For now just use a simple Dict to store golbal kv cache and a Deque to evict the least used key. - It can be easily extended and integrated with other kv cache pools that shared with other vllm instances. + For now, just use a simple Dict to store the global kv cache + and a Deque to evict the least recently used key. + It can be easily extended and integrated with other kv cache pools + that are shared with other vllm instances. """ def __init__(self, max_mem_util: float): self.cachedBlockNum: int = 0 @@ -19,7 +21,10 @@ def __init__(self, max_mem_util: float): self.blockHashDict_v: Dict[int, Dict[int, torch.Tensor]] = {} self.cachedBlockHashQ: Deque[int] = deque() - def setGlabalCacheBlockNum(self, model_config: ModelConfig, cache_config: CacheConfig, dtype: torch.dtype): + def setGlabalCacheBlockNum( + self, model_config: ModelConfig, + cache_config: CacheConfig, + dtype: torch.dtype): if self.cachedBlockNum > 0: logger.warning("global kv cache already enabled") return @@ -28,42 +33,54 @@ def setGlabalCacheBlockNum(self, model_config: ModelConfig, cache_config: CacheC return available_mem = psutil.virtual_memory().available * self.max_mem_util num_kv_heads = model_config.hf_text_config.num_attention_heads - head_size = model_config.hf_text_config.hidden_size // model_config.hf_text_config.num_attention_heads + head_size = (model_config.hf_text_config.hidden_size // + model_config.hf_text_config.num_attention_heads) num_attention_layers = model_config.hf_config.num_hidden_layers dtype_size = torch.tensor([], dtype=dtype).element_size() - key_size_bytes = dtype_size * cache_config.block_size * num_kv_heads * head_size * num_attention_layers - if key_size_bytes * 2 * cache_config.num_global_cache_blocks > available_mem: - logger.warning("num_global_cache_blocks too large, can not enable global kv cache, at most %d blocks can be used", available_mem // (key_size_bytes * 2)) + key_size_bytes = (dtype_size * cache_config.block_size * + num_kv_heads * head_size * num_attention_layers) + if (key_size_bytes * 2 * cache_config.num_global_cache_blocks > + available_mem): + logger.warning("num_global_cache_blocks too large, can not enable " + "global kv cache, at most %d blocks can be used", + available_mem // (key_size_bytes * 2)) return self.cachedBlockNum = cache_config.num_global_cache_blocks - logger.info("global kv cache enabled") + logger.info("global kv cache enabled: %d", self.cachedBlockNum) - def writeCache(self, block_hash: int, layer_idx: int, k_block_tensor: torch.Tensor, v_block_tensor: torch.Tensor): + def writeCache( + self, h: int, idx: int, + k_block_tensor: torch.Tensor, v_block_tensor: torch.Tensor): if self.cachedBlockNum == 0: return if len(self.cachedBlockHashQ) ==
self.cachedBlockNum: poped_block_hash = self.cachedBlockHashQ.popleft() del self.blockHashDict_k[poped_block_hash] del self.blockHashDict_v[poped_block_hash] - if block_hash not in self.blockHashDict_k or block_hash not in self.blockHashDict_v: - self.blockHashDict_k[block_hash] = {} - self.blockHashDict_v[block_hash] = {} + if (h not in self.blockHashDict_k or + h not in self.blockHashDict_v): + self.blockHashDict_k[h] = {} + self.blockHashDict_v[h] = {} else: - self.cachedBlockHashQ.remove(block_hash) - self.blockHashDict_k[block_hash][layer_idx] = k_block_tensor.to(device="cpu", non_blocking=True) - self.blockHashDict_v[block_hash][layer_idx] = v_block_tensor.to(device="cpu", non_blocking=True) - self.cachedBlockHashQ.append(block_hash) + self.cachedBlockHashQ.remove(h) - def readCache(self, block_hash: int, layer_idx: int, device: torch.device): + self.blockHashDict_k[h][idx] = \ + k_block_tensor.to(device="cpu", non_blocking=True) + self.blockHashDict_v[h][idx] = \ + v_block_tensor.to(device="cpu", non_blocking=True) + self.cachedBlockHashQ.append(h) + + def readCache(self, h: int, idx: int, device: torch.device): if self.cachedBlockNum == 0: return - if not self.checkExist(block_hash): + if not self.checkExist(h): return - self.cachedBlockHashQ.remove(block_hash) - self.cachedBlockHashQ.append(block_hash) - return self.blockHashDict_k[block_hash][layer_idx].to(torch.device(device), non_blocking=True), self.blockHashDict_v[block_hash][layer_idx].to(torch.device(device), non_blocking=True) + self.cachedBlockHashQ.remove(h) + self.cachedBlockHashQ.append(h) + return self.blockHashDict_k[h][idx].to(device, non_blocking=True), \ + self.blockHashDict_v[h][idx].to(device, non_blocking=True) - def checkExist(self, block_hash: int): - return block_hash in self.blockHashDict_k and block_hash in self.blockHashDict_v + def checkExist(self, h: Optional[int]): + return h in self.blockHashDict_k and h in self.blockHashDict_v global_cache_instance = GlobalCache(max_mem_util=0.8) \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0aa28181c4997..52fda671de0f9 100755 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -539,7 +539,12 @@ def _compute_for_prefix_cache_hit( remaining blocks. """ computed_block_nums = inter_data.computed_block_nums - block_global_computed_tables = inter_data.block_global_computed_tables[inter_data.seq_ids[seq_idx]] if inter_data.block_global_computed_tables is not None and len(inter_data.block_global_computed_tables) > 0 else [] + block_global_computed_tables = ( + inter_data.block_global_computed_tables[ + inter_data.seq_ids[seq_idx]] + if inter_data.block_global_computed_tables is not None and + len(inter_data.block_global_computed_tables) > 0 else [] + ) # Note that prefix caching does not support sliding window. 
prefix_cache_hit = (computed_block_nums is not None @@ -552,21 +557,31 @@ def _compute_for_prefix_cache_hit( and self.sliding_window is None and inter_data.is_prompt) - inter_data.prefix_cache_hit = prefix_cache_hit or global_prefix_cache_hit + inter_data.prefix_cache_hit = (prefix_cache_hit or + global_prefix_cache_hit) if not inter_data.prefix_cache_hit: return - assert computed_block_nums is not None or block_global_computed_tables is not None + assert (computed_block_nums is not None or + block_global_computed_tables is not None) - computed_max_len = max(len(computed_block_nums), len(block_global_computed_tables)) + computed_max_len = max(len(computed_block_nums), + len(block_global_computed_tables)) if len(block_global_computed_tables) > len(computed_block_nums): - block_global_computed = [block for block in block_global_computed_tables if block not in computed_block_nums] + block_global_computed = [ + block for block in block_global_computed_tables + if block not in computed_block_nums] for block_id in block_global_computed: - content_hash = inter_data.block_hash_map[inter_data.seq_ids[seq_idx]][block_id] - for layer_idx in range(self.runner.model_config.hf_config.num_hidden_layers): - self.kv_caches[layer_idx][0][block_id], self.kv_caches[layer_idx][1][block_id] = global_cache_instance.readCache(content_hash, layer_idx, self.runner.device) - torch.cuda.synchronize() + content_hash = inter_data.block_hash_map[ + inter_data.seq_ids[seq_idx]][block_id] + for layer_idx in range( + self.runner.model_config.hf_config.num_hidden_layers): + (self.kv_caches[layer_idx][0][block_id], + self.kv_caches[layer_idx][1][block_id]) = \ + global_cache_instance.readCache( + content_hash, layer_idx, self.runner.device) + torch.cuda.synchronize() # The cache hit prompt tokens in this sequence. Note that # this may be larger than the sequence length if chunked @@ -1247,7 +1262,8 @@ def _prepare_model_input_tensors( If cuda graph is required, this API automatically pads inputs. 
""" - builder = self._builder_cls(weakref.proxy(self), finished_requests_ids, kv_caches) + builder = self._builder_cls(weakref.proxy(self), + finished_requests_ids, kv_caches) for seq_group_metadata in seq_group_metadata_list: builder.add_seq_group(seq_group_metadata) @@ -1722,7 +1738,8 @@ def execute_model( device=self.device), **seqlen_agnostic_kwargs) - if model_input.is_prompt and self.cache_config.num_global_cache_blocks > 0: + if (model_input.is_prompt and + self.cache_config.num_global_cache_blocks > 0): self.write_global_cache(model_input, kv_caches) if (self.observability_config is not None @@ -1809,22 +1826,35 @@ def execute_model( return [output] - def write_global_cache(self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor]): + def write_global_cache( + self, model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor]): """ - for each layer and seq, get the block id and block hash, then write to global kv cache + for each layer and seq, get the block id and block hash, + then write to global kv cache """ + metadata = model_input.attn_metadata for i in range(self.model_config.hf_config.num_hidden_layers): - seq_start_index = 0 - if len(model_input.attn_metadata.block_hash_map) > 0 and model_input.attn_metadata.block_tables.numel() == 0: - for seq_idx, seq_length in enumerate(model_input.attn_metadata.seq_lens): + seq_start_idx = 0 + if (len(metadata.block_hash_map) > 0 and + metadata.block_tables.numel() == 0): + for seq_idx, seq_length in enumerate(metadata.seq_lens): num_blocks = seq_length // self.cache_config.block_size for idx in range(num_blocks): - block_id = model_input.attn_metadata.slot_mapping[seq_start_index + idx * self.cache_config.block_size].item() // self.cache_config.block_size - if block_id in model_input.attn_metadata.block_hash_map[seq_idx] and model_input.attn_metadata.block_hash_map[seq_idx][block_id] is not None: - block_hash = model_input.attn_metadata.block_hash_map[seq_idx][block_id] - global_cache_instance.writeCache(block_hash, i, kv_caches[i][0][block_id], kv_caches[i][1][block_id]) - - seq_start_index += seq_length + block_id = metadata.slot_mapping[ + seq_start_idx + idx * self.cache_config.block_size + ].item() // self.cache_config.block_size + if (block_id in metadata.block_hash_map[seq_idx] and + metadata.block_hash_map[ + seq_idx][block_id] is not None): + block_hash = metadata.block_hash_map[ + seq_idx][block_id] + global_cache_instance.writeCache( + block_hash, i, + kv_caches[i][0][block_id], + kv_caches[i][1][block_id]) + + seq_start_idx += seq_length def need_recv_kv(self, model_input, kv_caches) -> bool: """Check if we need to receive kv-cache from the other worker. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py old mode 100644 new mode 100755 index c33666e84ad72..756bf3db96a04 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -258,7 +258,8 @@ def _get_worker_input_from_broadcast( return model_input, worker_input, kwargs def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest, kv_caches: Optional[List[torch.Tensor]] = [], + self, execute_model_req: ExecuteModelRequest, + kv_caches: Optional[List[torch.Tensor]] = [], ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: """ Get the driver input and broadcast it to other workers. """ assert self.is_driver_worker @@ -306,7 +307,8 @@ def prepare_input( # notify all other workers to stop their execution loop. 
broadcast_tensor_dict({}, src=0) return None - return self._get_driver_input_and_broadcast(execute_model_req, kv_caches) + return self._get_driver_input_and_broadcast(execute_model_req, + kv_caches) else: return self._get_worker_input_from_broadcast() @@ -318,7 +320,8 @@ def execute_model( sequences are provided.""" start_time = time.perf_counter() - inputs = self.prepare_input(execute_model_req, kv_caches=self.kv_cache[execute_model_req.virtual_engine]) + inputs = self.prepare_input(execute_model_req, + kv_caches=self.kv_cache[execute_model_req.virtual_engine]) if inputs is None: return None
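For reviewers following the KV-block flow, below is a minimal, self-contained sketch (not part of the patch) of the dict-plus-deque LRU scheme that GlobalCache.writeCache/readCache implement: blocks are keyed by content hash, each hash maps a layer index to a CPU copy of the key/value block, and a deque tracks recency so the least recently used hash is evicted once capacity is reached. The class name TinyGlobalKVCache, the block shape, and the torch.randn blocks are illustrative placeholders, and the eviction order is simplified relative to the patch.

```python
from collections import deque
from typing import Deque, Dict, Optional, Tuple

import torch


class TinyGlobalKVCache:
    """Simplified stand-in for GlobalCache: dict storage + deque-based LRU."""

    def __init__(self, max_blocks: int):
        self.max_blocks = max_blocks
        # block hash -> {layer index -> (key block, value block) kept on CPU}
        self._store: Dict[int, Dict[int, Tuple[torch.Tensor, torch.Tensor]]] = {}
        # Least recently used hash sits at the left end of the deque.
        self._lru: Deque[int] = deque()

    def write(self, block_hash: int, layer_idx: int,
              k_block: torch.Tensor, v_block: torch.Tensor) -> None:
        if block_hash in self._store:
            # Known hash: refresh its recency instead of evicting.
            self._lru.remove(block_hash)
        elif len(self._lru) == self.max_blocks:
            # Capacity reached: drop the least recently used hash.
            evicted = self._lru.popleft()
            del self._store[evicted]
        layers = self._store.setdefault(block_hash, {})
        layers[layer_idx] = (k_block.to("cpu"), v_block.to("cpu"))
        self._lru.append(block_hash)

    def read(self, block_hash: int, layer_idx: int,
             device: torch.device
             ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        if block_hash not in self._store:
            return None
        # A hit also refreshes recency, mirroring readCache in the patch.
        self._lru.remove(block_hash)
        self._lru.append(block_hash)
        k_block, v_block = self._store[block_hash][layer_idx]
        return k_block.to(device), v_block.to(device)


def new_block() -> torch.Tensor:
    # Illustrative block shape: (block_size, num_kv_heads, head_size).
    return torch.randn(16, 8, 64)


if __name__ == "__main__":
    cache = TinyGlobalKVCache(max_blocks=2)
    cache.write(block_hash=111, layer_idx=0,
                k_block=new_block(), v_block=new_block())
    cache.write(block_hash=222, layer_idx=0,
                k_block=new_block(), v_block=new_block())
    cache.read(111, 0, torch.device("cpu"))   # refreshes hash 111
    cache.write(333, 0, new_block(), new_block())  # evicts LRU hash 222
    print(list(cache._lru))                    # -> [111, 333]
```

In the patch itself this logic is driven from two places: model_runner.write_global_cache pushes prompt KV blocks into the cache after prefill, and _compute_for_prefix_cache_hit pulls them back into kv_caches for blocks the scheduler marked as global_computed.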