Commit
[fix] fix CI warnings
Signed-off-by: Jingxin Pan <[email protected]>
lyppg committed Dec 30, 2024
1 parent 289c544 commit 6e58711
Showing 7 changed files with 75 additions and 24 deletions.
6 changes: 4 additions & 2 deletions examples/offline_inference_with_global_prefix.py
@@ -85,8 +85,10 @@

"""
We can simulate the global prefix cache this way:
1. Run the vllm instance with APC for some time, so some prompts may not hit APC as they are old and evicted
2. Delete the first vllm instance and start a new one. In this case, global kv cache can be hit directly
1. Run the vllm instance with APC for some time, so some prompts may not
hit APC as they are old and evicted
2. Delete the first vllm instance and start a new one. In this case,
global kv cache can be hit directly
Here we demo the second option.
"""
# Destroy the LLM object and free up the GPU memory.
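The flow described in the docstring above can be sketched as a standalone script. This is illustrative only: it assumes num_global_cache_blocks is accepted as an LLM constructor argument by this patch (mirroring cache_config.num_global_cache_blocks), and it uses the usual del/gc cleanup pattern from the vLLM offline examples.

import gc

import torch
from vllm import LLM, SamplingParams

prompts = ["The capital of France is"]
sampling_params = SamplingParams(temperature=0.0)

# First instance: automatic prefix caching (APC) plus the global CPU-side cache.
llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          num_global_cache_blocks=1000)  # assumed kwarg, see note above
print(llm.generate(prompts, sampling_params))

# Destroy the instance and free GPU memory, then start a fresh one. The new
# instance begins with an empty GPU prefix cache, so any cache hit for the
# same prompts has to come from the global (CPU) kv cache kept in-process.
del llm
gc.collect()
torch.cuda.empty_cache()

llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          num_global_cache_blocks=1000)  # assumed kwarg
print(llm.generate(prompts, sampling_params))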
8 changes: 8 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
100644 → 100755
@@ -416,6 +416,14 @@ def computed(self):
def computed(self, value):
self._proxy.computed = value

@property
def global_computed(self):
return self._proxy.global_computed

@global_computed.setter
def global_computed(self, value):
self._proxy.global_computed = value

@property
def last_accessed(self) -> float:
return self._proxy.last_accessed
10 changes: 10 additions & 0 deletions vllm/core/block/interfaces.py
100644 → 100755
@@ -66,6 +66,16 @@ def computed(self, value) -> bool:
"""Should be only used by PrefixCacingAllocator"""
raise NotImplementedError

@property
@abstractmethod
def global_computed(self) -> bool:
raise NotImplementedError

@global_computed.setter
@abstractmethod
def global_computed(self, value) -> bool:
raise NotImplementedError

@property
@abstractmethod
def last_accessed(self) -> float:
8 changes: 8 additions & 0 deletions vllm/core/block/naive_block.py
100644 → 100755
@@ -404,6 +404,14 @@ def computed(self) -> bool:
def computed(self, value) -> None:
raise NotImplementedError

@property
def global_computed(self) -> bool:
raise NotImplementedError

@global_computed.setter
def global_computed(self, value) -> None:
raise NotImplementedError

@property
def last_accessed(self) -> float:
raise NotImplementedError
13 changes: 8 additions & 5 deletions vllm/core/scheduler.py
100644 → 100755
@@ -1336,19 +1336,22 @@ def schedule(
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)

if self.cache_config.enable_prefix_caching and self.cache_config.num_global_cache_blocks > 0 and seq_group.is_prefill():
if (self.cache_config.enable_prefix_caching and
self.cache_config.num_global_cache_blocks > 0 and
seq_group.is_prefill()):
global_computed_list = []
block_hash_dict = {}
block_h_dict = {}
for block_id in block_tables[seq_id]:
for block in self.block_manager.block_tables[seq_id].blocks:
for block in \
self.block_manager.block_tables[seq_id].blocks:
if block.block_id == block_id:
if block.global_computed:
global_computed_list.append(block_id)
if block.content_hash is not None:
block_hash_dict[block_id] = block.content_hash
block_h_dict[block_id] = block.content_hash
break
block_global_computed_tables[seq_id] = global_computed_list
block_hash_map[seq_id] = block_hash_dict
block_hash_map[seq_id] = block_h_dict

self.block_manager.access_all_blocks_in_seq(seq, now)

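What the new prefill branch above collects per sequence can be shown with a small self-contained sketch; ToyBlock is a hypothetical stand-in exposing only the three attributes the loop reads (block_id, global_computed, content_hash).

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToyBlock:  # hypothetical stand-in for the allocator's block objects
    block_id: int
    global_computed: bool
    content_hash: Optional[int]


def collect_global_info(block_table: List[int], blocks: List[ToyBlock]):
    """Mirror the scheduler loop: for each block_id in the sequence's block
    table, record whether that block came from the global cache and remember
    its content hash (when it has one)."""
    global_computed_list: List[int] = []
    block_h_dict: Dict[int, int] = {}
    for block_id in block_table:
        for block in blocks:
            if block.block_id == block_id:
                if block.global_computed:
                    global_computed_list.append(block_id)
                if block.content_hash is not None:
                    block_h_dict[block_id] = block.content_hash
                break
    return global_computed_list, block_h_dict


blocks = [ToyBlock(0, True, 111), ToyBlock(1, False, 222), ToyBlock(2, False, None)]
print(collect_global_info([0, 1, 2], blocks))  # ([0], {0: 111, 1: 222})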
8 changes: 6 additions & 2 deletions vllm/engine/llm_engine.py
@@ -371,8 +371,12 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
if self.model_config.use_async_output_proc else None)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
if self.cache_config.enable_prefix_caching and self.cache_config.num_global_cache_blocks > 0:
global_cache_instance.setGlabalCacheBlockNum(self.model_config, self.cache_config, self.model_executor.driver_worker.cache_engine[0].dtype)
if (self.cache_config.enable_prefix_caching and
self.cache_config.num_global_cache_blocks > 0):
global_cache_instance.setGlabalCacheBlockNum(
self.model_config,
self.cache_config,
self.model_executor.driver_worker.cache_engine[0].dtype)

# Metric Logging.
if self.log_stats:
46 changes: 31 additions & 15 deletions vllm/global_cache.py
@@ -1,5 +1,5 @@
from collections import deque
from typing import Deque, Dict
from typing import Deque, Dict, Optional
import torch
from vllm.config import ModelConfig, CacheConfig
from vllm.logger import init_logger
@@ -9,8 +9,10 @@

class GlobalCache:
"""
For now just use a simple Dict to store global kv cache and a Deque to evict the least used key.
It can be easily extended and integrated with other kv cache pools that are shared with other vllm instances.
For now just use a simple Dict to store global kv cache
and a Deque to evict the least used key.
It can be easily extended and integrated with other kv cache pools
that are shared with other vllm instances.
"""
def __init__(self, max_mem_util: float):
self.cachedBlockNum: int = 0
@@ -19,7 +21,10 @@ def __init__(self, max_mem_util: float):
self.blockHashDict_v: Dict[int, Dict[int, torch.Tensor]] = {}
self.cachedBlockHashQ: Deque[int] = deque()

def setGlabalCacheBlockNum(self, model_config: ModelConfig, cache_config: CacheConfig, dtype: torch.dtype):
def setGlabalCacheBlockNum(
self, model_config: ModelConfig,
cache_config: CacheConfig,
dtype: torch.dtype):
if self.cachedBlockNum > 0:
logger.warning("global kv cache already enabled")
return
@@ -28,30 +33,41 @@ def __init__(self, max_mem_util: float):
return
available_mem = psutil.virtual_memory().available * self.max_mem_util
num_kv_heads = model_config.hf_text_config.num_attention_heads
head_size = model_config.hf_text_config.hidden_size // model_config.hf_text_config.num_attention_heads
head_size = (model_config.hf_text_config.hidden_size //
model_config.hf_text_config.num_attention_heads)
num_attention_layers = model_config.hf_config.num_hidden_layers
dtype_size = torch.tensor([], dtype=dtype).element_size()
key_size_bytes = dtype_size * cache_config.block_size * num_kv_heads * head_size * num_attention_layers
if key_size_bytes * 2 * cache_config.num_global_cache_blocks > available_mem:
logger.warning("num_global_cache_blocks too large, can not enable global kv cache, at most %d blocks can be used", available_mem // (key_size_bytes * 2))
key_size_bytes = (dtype_size * cache_config.block_size *
num_kv_heads * head_size * num_attention_layers)
if (key_size_bytes * 2 * cache_config.num_global_cache_blocks >
available_mem):
logger.warning("num_global_cache_blocks too large, can not enable "
"global kv cache, at most %d blocks can be used",
available_mem // (key_size_bytes * 2))
return
self.cachedBlockNum = cache_config.num_global_cache_blocks
logger.info("global kv cache enabled")
logger.info("global kv cache enabled: %d", self.cachedBlockNum)

def writeCache(self, block_hash: int, layer_idx: int, k_block_tensor: torch.Tensor, v_block_tensor: torch.Tensor):
def writeCache(
self, block_hash: int, layer_idx: int,
k_block_tensor: torch.Tensor, v_block_tensor: torch.Tensor):
if self.cachedBlockNum == 0:
return
if len(self.cachedBlockHashQ) == self.cachedBlockNum:
poped_block_hash = self.cachedBlockHashQ.popleft()
del self.blockHashDict_k[poped_block_hash]
del self.blockHashDict_v[poped_block_hash]
if block_hash not in self.blockHashDict_k or block_hash not in self.blockHashDict_v:
if (block_hash not in self.blockHashDict_k or
block_hash not in self.blockHashDict_v):
self.blockHashDict_k[block_hash] = {}
self.blockHashDict_v[block_hash] = {}
else:
self.cachedBlockHashQ.remove(block_hash)
self.blockHashDict_k[block_hash][layer_idx] = k_block_tensor.to(device="cpu", non_blocking=True)
self.blockHashDict_v[block_hash][layer_idx] = v_block_tensor.to(device="cpu", non_blocking=True)

self.blockHashDict_k[block_hash][layer_idx] = \
k_block_tensor.to(device="cpu", non_blocking=True)
self.blockHashDict_v[block_hash][layer_idx] = \
v_block_tensor.to(device="cpu", non_blocking=True)
self.cachedBlockHashQ.append(block_hash)

def readCache(self, block_hash: int, layer_idx: int, device: torch.device):
@@ -61,9 +77,9 @@ def readCache(self, block_hash: int, layer_idx: int, device: torch.device):
return
self.cachedBlockHashQ.remove(block_hash)
self.cachedBlockHashQ.append(block_hash)
return self.blockHashDict_k[block_hash][layer_idx].to(torch.device(device), non_blocking=True), self.blockHashDict_v[block_hash][layer_idx].to(torch.device(device), non_blocking=True)
return self.blockHashDict_k[block_hash][layer_idx].to(device, non_blocking=True), self.blockHashDict_v[block_hash][layer_idx].to(device, non_blocking=True)

Check failure (GitHub Actions / ruff 3.12): vllm/global_cache.py:80:81: E501 Line too long (163 > 80)

def checkExist(self, block_hash: int):
def checkExist(self, block_hash: Optional[int]):
return block_hash in self.blockHashDict_k and block_hash in self.blockHashDict_v

Check failure (GitHub Actions / ruff 3.12): vllm/global_cache.py:83:81: E501 Line too long (88 > 80)

global_cache_instance = GlobalCache(max_mem_util=0.8)
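The available-memory check in setGlabalCacheBlockNum above can be made concrete with a worked example. The dimensions below are illustrative Llama-7B-like values, not taken from any particular config; the formula itself matches the code, which sizes one block's keys across all layers and doubles it for keys plus values.

import torch

block_size = 16             # cache_config.block_size
num_kv_heads = 32           # num_attention_heads, as used in the code above
head_size = 4096 // 32      # hidden_size // num_attention_heads = 128
num_attention_layers = 32   # num_hidden_layers
dtype_size = torch.tensor([], dtype=torch.float16).element_size()  # 2 bytes

key_size_bytes = (dtype_size * block_size * num_kv_heads *
                  head_size * num_attention_layers)
print(key_size_bytes)               # 4194304 bytes, i.e. 4 MiB of keys per block
print(key_size_bytes * 2 // 2**20)  # 8 MiB per block for keys plus values

# With these numbers, num_global_cache_blocks=1000 needs roughly 8 GiB of free
# CPU memory to pass the check against psutil available memory * max_mem_util.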
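A minimal exercise of the GlobalCache class above, showing the deque-based LRU eviction; for brevity the block budget is set directly on the instance instead of going through setGlabalCacheBlockNum, which needs real model and cache configs.

import torch
from vllm.global_cache import GlobalCache

cache = GlobalCache(max_mem_util=0.8)
cache.cachedBlockNum = 2   # demo shortcut: skip setGlabalCacheBlockNum sizing

k = torch.randn(16, 32, 128)   # toy per-layer key block
v = torch.randn(16, 32, 128)   # toy per-layer value block

cache.writeCache(block_hash=111, layer_idx=0, k_block_tensor=k, v_block_tensor=v)
cache.writeCache(block_hash=222, layer_idx=0, k_block_tensor=k, v_block_tensor=v)
print(cache.checkExist(111), cache.checkExist(222))   # True True

# Writing a third hash evicts the least recently used entry (111).
cache.writeCache(block_hash=333, layer_idx=0, k_block_tensor=k, v_block_tensor=v)
print(cache.checkExist(111), cache.checkExist(333))   # False True

# readCache returns the CPU copies moved to the requested device and also
# refreshes the entry's position in the LRU deque.
k_back, v_back = cache.readCache(222, 0, torch.device("cpu"))
print(k_back.shape, v_back.shape)   # both torch.Size([16, 32, 128])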
