[fix] fix CI warnings
Signed-off-by: Jingxin Pan <[email protected]>
lyppg committed Dec 30, 2024
1 parent 289c544 commit e28cb94
Showing 8 changed files with 132 additions and 53 deletions.
6 changes: 4 additions & 2 deletions examples/offline_inference_with_global_prefix.py
@@ -85,8 +85,10 @@

"""
We can simulate the global prefix cache this way:
1. Run the vllm instance with APC for some time, so some prompts may not hit APC as they are old and evicted
2. Delete the first vllm instance and start a new one. In this case, global kv cache can be hit directly
1. Run the vllm instance with APC for some time, so some prompts may not
hit APC as they are old and evicted
2. Delete the first vllm instance and start a new one. In this case,
global kv cache can be hit directly
Here we demo the second option.
"""
# Destroy the LLM object and free up the GPU memory.
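As an aside to the docstring above, here is a minimal sketch of the second option: tear down one engine, start another, and let prefill hit the global KV cache. It assumes the LLM entry point accepts enable_prefix_caching and the num_global_cache_blocks knob this change adds; the model name, block count, and cleanup calls are illustrative and not taken from the example file.

# Editor's illustrative sketch, not part of the commit. Build an engine with
# automatic prefix caching (APC) plus the global CPU-side cache, destroy it,
# then start a fresh engine whose prefill can reuse the globally cached blocks.
import gc

import torch
from vllm import LLM, SamplingParams

prompts = ["Tell me a long story about a prefix cache."] * 4
params = SamplingParams(temperature=0.0, max_tokens=64)

# First instance fills both APC and the global KV cache
# (num_global_cache_blocks is assumed to be plumbed through as an engine arg).
llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          num_global_cache_blocks=4096)
llm.generate(prompts, params)

# Destroy the first instance and free GPU memory; the CPU-side global cache
# survives because global_cache_instance is module-level state.
del llm
gc.collect()
torch.cuda.empty_cache()

# Second instance: its APC starts empty, but the same prompts can be served
# from the global KV cache.
llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          num_global_cache_blocks=4096)
llm.generate(prompts, params)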
8 changes: 8 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
100644 → 100755
@@ -416,6 +416,14 @@ def computed(self):
def computed(self, value):
self._proxy.computed = value

@property
def global_computed(self):
return self._proxy.global_computed

@global_computed.setter
def global_computed(self, value):
self._proxy.global_computed = value

@property
def last_accessed(self) -> float:
return self._proxy.last_accessed
10 changes: 10 additions & 0 deletions vllm/core/block/interfaces.py
100644 → 100755
@@ -66,6 +66,16 @@ def computed(self, value) -> bool:
"""Should be only used by PrefixCacingAllocator"""
raise NotImplementedError

@property
@abstractmethod
def global_computed(self) -> bool:
raise NotImplementedError

@global_computed.setter
@abstractmethod
def global_computed(self, value) -> bool:
raise NotImplementedError

@property
@abstractmethod
def last_accessed(self) -> float:
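A toy illustration of how a concrete block type can back the new abstract global_computed property with a plain flag. The class name and field are made up; the real implementations are the proxy forwarding in cpu_gpu_block_allocator.py above and the NotImplementedError stubs in naive_block.py below.

# Editor's illustrative sketch, not part of the commit.
class ToyPrefixCachingBlock:

    def __init__(self) -> None:
        self._global_computed = False

    @property
    def global_computed(self) -> bool:
        # True once this block's KV contents were restored from (or written
        # to) the global CPU-side cache.
        return self._global_computed

    @global_computed.setter
    def global_computed(self, value: bool) -> None:
        self._global_computed = value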
8 changes: 8 additions & 0 deletions vllm/core/block/naive_block.py
100644 → 100755
@@ -404,6 +404,14 @@ def computed(self) -> bool:
def computed(self, value) -> None:
raise NotImplementedError

@property
def global_computed(self) -> bool:
raise NotImplementedError

@global_computed.setter
def global_computed(self, value) -> None:
raise NotImplementedError

@property
def last_accessed(self) -> float:
raise NotImplementedError
13 changes: 8 additions & 5 deletions vllm/core/scheduler.py
100644 → 100755
@@ -1336,19 +1336,22 @@ def schedule(
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)

if self.cache_config.enable_prefix_caching and self.cache_config.num_global_cache_blocks > 0 and seq_group.is_prefill():
if (self.cache_config.enable_prefix_caching and
self.cache_config.num_global_cache_blocks > 0 and
seq_group.is_prefill()):
global_computed_list = []
block_hash_dict = {}
block_h_dict = {}
for block_id in block_tables[seq_id]:
for block in self.block_manager.block_tables[seq_id].blocks:
for block in \
self.block_manager.block_tables[seq_id].blocks:
if block.block_id == block_id:
if block.global_computed:
global_computed_list.append(block_id)
if block.content_hash is not None:
block_hash_dict[block_id] = block.content_hash
block_h_dict[block_id] = block.content_hash
break
block_global_computed_tables[seq_id] = global_computed_list
block_hash_map[seq_id] = block_hash_dict
block_hash_map[seq_id] = block_h_dict

self.block_manager.access_all_blocks_in_seq(seq, now)

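For reference, a small self-contained reconstruction of the bookkeeping the scheduler builds in the hunk above: for each sequence, the block ids whose KV already sits in the global cache and a {block_id: content_hash} map. ToyBlock and the sample values are assumptions used only to make the shapes concrete.

# Editor's illustrative sketch, not part of the commit.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class ToyBlock:
    block_id: int
    content_hash: Optional[int]
    global_computed: bool


def collect_global_cache_info(
        block_table: List[int],
        blocks: List[ToyBlock]) -> Tuple[List[int], Dict[int, int]]:
    # Mirrors the loop above: walk the sequence's block table, find the
    # matching block object, and record global-cache hits and content hashes.
    global_computed_list: List[int] = []
    block_hash: Dict[int, int] = {}
    for block_id in block_table:
        for block in blocks:
            if block.block_id == block_id:
                if block.global_computed:
                    global_computed_list.append(block_id)
                if block.content_hash is not None:
                    block_hash[block_id] = block.content_hash
                break
    return global_computed_list, block_hash


blocks = [ToyBlock(0, 111, True), ToyBlock(1, 222, False), ToyBlock(2, None, False)]
print(collect_global_cache_info([0, 1, 2], blocks))  # ([0], {0: 111, 1: 222})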
8 changes: 6 additions & 2 deletions vllm/engine/llm_engine.py
@@ -371,8 +371,12 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
if self.model_config.use_async_output_proc else None)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
if self.cache_config.enable_prefix_caching and self.cache_config.num_global_cache_blocks > 0:
global_cache_instance.setGlabalCacheBlockNum(self.model_config, self.cache_config, self.model_executor.driver_worker.cache_engine[0].dtype)
if (self.cache_config.enable_prefix_caching and
self.cache_config.num_global_cache_blocks > 0):
global_cache_instance.setGlabalCacheBlockNum(
self.model_config,
self.cache_config,
self.model_executor.driver_worker.cache_engine[0].dtype)

# Metric Logging.
if self.log_stats:
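The setGlabalCacheBlockNum call above sizes the CPU-side cache against available host memory; the check itself lives in vllm/global_cache.py below. A rough worked example of that capacity arithmetic, with assumed model dimensions rather than values from any real config:

# Editor's worked example, not part of the commit. Numbers are made up for a
# small model; key_size_bytes covers one block's K cache across all layers,
# and the factor of 2 below accounts for V.
import torch

block_size = 16            # tokens per KV block (cache_config.block_size)
num_kv_heads = 12          # the code uses num_attention_heads here
head_size = 64             # hidden_size // num_attention_heads, assumed 768 // 12
num_attention_layers = 12  # hf_config.num_hidden_layers, assumed
dtype_size = torch.tensor([], dtype=torch.float16).element_size()  # 2 bytes

key_size_bytes = (dtype_size * block_size * num_kv_heads *
                  head_size * num_attention_layers)
print(key_size_bytes)              # 294912 bytes (288 KiB) per block, K only

available_mem = 8 * 1024**3 * 0.8  # pretend 8 GiB free, max_mem_util = 0.8
max_blocks = int(available_mem // (key_size_bytes * 2))
print(max_blocks)                  # ~11650 blocks; larger requests are rejected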
65 changes: 41 additions & 24 deletions vllm/global_cache.py
@@ -1,5 +1,5 @@
from collections import deque
from typing import Deque, Dict
from typing import Deque, Dict, Optional
import torch
from vllm.config import ModelConfig, CacheConfig
from vllm.logger import init_logger
@@ -9,8 +9,10 @@

class GlobalCache:
"""
For now just use a simple Dict to store golbal kv cache and a Deque to evict the least used key.
It can be easily extended and integrated with other kv cache pools that shared with other vllm instances.
For now just use a simple Dict to store golbal kv cache
and a Deque to evict the least used key.
It can be easily extended and integrated with other kv cache pools
that shared with other vllm instances.
"""
def __init__(self, max_mem_util: float):
self.cachedBlockNum: int = 0
@@ -19,7 +21,10 @@ def __init__(self, max_mem_util: float):
self.blockHashDict_v: Dict[int, Dict[int, torch.Tensor]] = {}
self.cachedBlockHashQ: Deque[int] = deque()

def setGlabalCacheBlockNum(self, model_config: ModelConfig, cache_config: CacheConfig, dtype: torch.dtype):
def setGlabalCacheBlockNum(
self, model_config: ModelConfig,
cache_config: CacheConfig,
dtype: torch.dtype):
if self.cachedBlockNum > 0:
logger.warning("global kv cache already enabled")
return
@@ -28,42 +33,54 @@ def setGlabalCacheBlockNum(self, model_config: ModelConfig, cache_config: CacheC
return
available_mem = psutil.virtual_memory().available * self.max_mem_util
num_kv_heads = model_config.hf_text_config.num_attention_heads
head_size = model_config.hf_text_config.hidden_size // model_config.hf_text_config.num_attention_heads
head_size = (model_config.hf_text_config.hidden_size //
model_config.hf_text_config.num_attention_heads)
num_attention_layers = model_config.hf_config.num_hidden_layers
dtype_size = torch.tensor([], dtype=dtype).element_size()
key_size_bytes = dtype_size * cache_config.block_size * num_kv_heads * head_size * num_attention_layers
if key_size_bytes * 2 * cache_config.num_global_cache_blocks > available_mem:
logger.warning("num_global_cache_blocks too large, can not enable global kv cache, at most %d blocks can be used", available_mem // (key_size_bytes * 2))
key_size_bytes = (dtype_size * cache_config.block_size *
num_kv_heads * head_size * num_attention_layers)
if (key_size_bytes * 2 * cache_config.num_global_cache_blocks >
available_mem):
logger.warning("num_global_cache_blocks too large, can not enable "
"global kv cache, at most %d blocks can be used",
available_mem // (key_size_bytes * 2))
return
self.cachedBlockNum = cache_config.num_global_cache_blocks
logger.info("global kv cache enabled")
logger.info("global kv cache enabled: %d", self.cachedBlockNum)

def writeCache(self, block_hash: int, layer_idx: int, k_block_tensor: torch.Tensor, v_block_tensor: torch.Tensor):
def writeCache(
self, h: int, idx: int,
k_block_tensor: torch.Tensor, v_block_tensor: torch.Tensor):
if self.cachedBlockNum == 0:
return
if len(self.cachedBlockHashQ) == self.cachedBlockNum:
poped_block_hash = self.cachedBlockHashQ.popleft()
del self.blockHashDict_k[poped_block_hash]
del self.blockHashDict_v[poped_block_hash]
if block_hash not in self.blockHashDict_k or block_hash not in self.blockHashDict_v:
self.blockHashDict_k[block_hash] = {}
self.blockHashDict_v[block_hash] = {}
if (h not in self.blockHashDict_k or
h not in self.blockHashDict_v):
self.blockHashDict_k[h] = {}
self.blockHashDict_v[h] = {}
else:
self.cachedBlockHashQ.remove(block_hash)
self.blockHashDict_k[block_hash][layer_idx] = k_block_tensor.to(device="cpu", non_blocking=True)
self.blockHashDict_v[block_hash][layer_idx] = v_block_tensor.to(device="cpu", non_blocking=True)
self.cachedBlockHashQ.append(block_hash)
self.cachedBlockHashQ.remove(h)

def readCache(self, block_hash: int, layer_idx: int, device: torch.device):
self.blockHashDict_k[h][idx] = \
k_block_tensor.to(device="cpu", non_blocking=True)
self.blockHashDict_v[h][idx] = \
v_block_tensor.to(device="cpu", non_blocking=True)
self.cachedBlockHashQ.append(h)

def readCache(self, h: int, idx: int, device: torch.device):
if self.cachedBlockNum == 0:
return
if not self.checkExist(block_hash):
if not self.checkExist(h):
return
self.cachedBlockHashQ.remove(block_hash)
self.cachedBlockHashQ.append(block_hash)
return self.blockHashDict_k[block_hash][layer_idx].to(torch.device(device), non_blocking=True), self.blockHashDict_v[block_hash][layer_idx].to(torch.device(device), non_blocking=True)
self.cachedBlockHashQ.remove(h)
self.cachedBlockHashQ.append(h)
return self.blockHashDict_k[h][idx].to(device, non_blocking=True), \
self.blockHashDict_v[h][idx].to(device, non_blocking=True)

def checkExist(self, block_hash: int):
return block_hash in self.blockHashDict_k and block_hash in self.blockHashDict_v
def checkExist(self, h: Optional[int]):
return h in self.blockHashDict_k and h in self.blockHashDict_v

global_cache_instance = GlobalCache(max_mem_util=0.8)
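A compact standalone sketch of the eviction scheme the docstring above describes: a dict for the cached tensors plus a deque ordered from least to most recently used, with the left end evicted once capacity is reached. It keeps one tensor per key instead of the per-layer K/V dicts of the real class.

# Editor's illustrative sketch, not part of the commit.
from collections import deque
from typing import Deque, Dict

import torch


class TinyLRUCache:

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.data: Dict[int, torch.Tensor] = {}
        self.order: Deque[int] = deque()  # leftmost = least recently used

    def write(self, key: int, value: torch.Tensor) -> None:
        if key in self.data:
            self.order.remove(key)          # refresh recency on rewrite
        elif len(self.order) == self.capacity:
            evicted = self.order.popleft()  # drop the least recently used key
            del self.data[evicted]
        self.data[key] = value.to("cpu", non_blocking=True)
        self.order.append(key)

    def read(self, key: int, device: torch.device):
        if key not in self.data:
            return None
        self.order.remove(key)              # mark as most recently used
        self.order.append(key)
        return self.data[key].to(device, non_blocking=True)


cache = TinyLRUCache(capacity=2)
cache.write(111, torch.zeros(4))
cache.write(222, torch.ones(4))
cache.write(333, torch.full((4,), 2.0))      # evicts key 111
print(list(cache.order))                     # [222, 333]
print(cache.read(111, torch.device("cpu")))  # None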
67 changes: 47 additions & 20 deletions vllm/worker/model_runner.py
@@ -539,7 +539,12 @@ def _compute_for_prefix_cache_hit(
remaining blocks.
"""
computed_block_nums = inter_data.computed_block_nums
block_global_computed_tables = inter_data.block_global_computed_tables[inter_data.seq_ids[seq_idx]] if inter_data.block_global_computed_tables is not None and len(inter_data.block_global_computed_tables) > 0 else []
block_global_computed_tables = (
inter_data.block_global_computed_tables[
inter_data.seq_ids[seq_idx]]
if inter_data.block_global_computed_tables is not None and
len(inter_data.block_global_computed_tables) > 0 else []
)

# Note that prefix caching does not support sliding window.
prefix_cache_hit = (computed_block_nums is not None
@@ -552,21 +557,30 @@
and self.sliding_window is None
and inter_data.is_prompt)

inter_data.prefix_cache_hit = prefix_cache_hit or global_prefix_cache_hit
inter_data.prefix_cache_hit = (prefix_cache_hit or
global_prefix_cache_hit)

if not inter_data.prefix_cache_hit:
return

assert computed_block_nums is not None or block_global_computed_tables is not None
assert (computed_block_nums is not None or
block_global_computed_tables is not None)

computed_max_len = max(len(computed_block_nums), len(block_global_computed_tables))
computed_max_len = max(len(computed_block_nums),
len(block_global_computed_tables))
if len(block_global_computed_tables) > len(computed_block_nums):
block_global_computed = [block for block in block_global_computed_tables if block not in computed_block_nums]
block_global_computed = [block for block in block_global_computed_tables
if block not in computed_block_nums]
for block_id in block_global_computed:

Check failure (GitHub Actions / ruff 3.12): vllm/worker/model_runner.py:574:81: E501 Line too long (85 > 80)
content_hash = inter_data.block_hash_map[inter_data.seq_ids[seq_idx]][block_id]
for layer_idx in range(self.runner.model_config.hf_config.num_hidden_layers):
self.kv_caches[layer_idx][0][block_id], self.kv_caches[layer_idx][1][block_id] = global_cache_instance.readCache(content_hash, layer_idx, self.runner.device)
torch.cuda.synchronize()
content_hash = inter_data.block_hash_map[
inter_data.seq_ids[seq_idx]][block_id]
for layer_idx in range(
self.runner.model_config.hf_config.num_hidden_layers):
(self.kv_caches[layer_idx][0][block_id],
self.kv_caches[layer_idx][1][block_id]) = \
global_cache_instance.readCache(
content_hash, layer_idx, self.runner.device)
torch.cuda.synchronize()

Check failure (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): vllm/worker/model_runner.py:577: Value of type "dict[int, dict[int, int]] | None" is not indexable [index]
Check failure (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): vllm/worker/model_runner.py:581: Value of type "list[Any] | None" is not indexable [index]
Check failure (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): vllm/worker/model_runner.py:582: Value of type "list[Any] | None" is not indexable [index]

# The cache hit prompt tokens in this sequence. Note that
# this may be larger than the sequence length if chunked
@@ -1809,22 +1823,35 @@ def execute_model(

return [output]

def write_global_cache(self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor]):
def write_global_cache(
self, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor]):
"""
for each layer and seq, get the block id and block hash, then write to global kv cache
for each layer and seq, get the block id and block hash,
then write to global kv cache
"""
metadata = model_input.attn_metadata
for i in range(self.model_config.hf_config.num_hidden_layers):
seq_start_index = 0
if len(model_input.attn_metadata.block_hash_map) > 0 and model_input.attn_metadata.block_tables.numel() == 0:
for seq_idx, seq_length in enumerate(model_input.attn_metadata.seq_lens):
seq_start_idx = 0
if (len(metadata.block_hash_map) > 0 and
metadata.block_tables.numel() == 0):
for seq_idx, seq_length in enumerate(metadata.seq_lens):
num_blocks = seq_length // self.cache_config.block_size
for idx in range(num_blocks):
block_id = model_input.attn_metadata.slot_mapping[seq_start_index + idx * self.cache_config.block_size].item() // self.cache_config.block_size
if block_id in model_input.attn_metadata.block_hash_map[seq_idx] and model_input.attn_metadata.block_hash_map[seq_idx][block_id] is not None:
block_hash = model_input.attn_metadata.block_hash_map[seq_idx][block_id]
global_cache_instance.writeCache(block_hash, i, kv_caches[i][0][block_id], kv_caches[i][1][block_id])

seq_start_index += seq_length
block_id = metadata.slot_mapping[
seq_start_idx + idx * self.cache_config.block_size
].item() // self.cache_config.block_size
if (block_id in metadata.block_hash_map[seq_idx] and
metadata.block_hash_map[
seq_idx][block_id] is not None):
block_hash = metadata.block_hash_map[
seq_idx][block_id]
global_cache_instance.writeCache(
block_hash, i,
kv_caches[i][0][block_id],
kv_caches[i][1][block_id])

Check failure (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): vllm/worker/model_runner.py:1844: Item "None" of "Any | None" has no attribute "block_hash_map" [union-attr]
Check failure (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): vllm/worker/model_runner.py:1845: Item "None" of "Any | None" has no attribute "block_tables" [union-attr]
Check failure (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): vllm/worker/model_runner.py:1846: Item "None" of "Any | None" has no attribute "seq_lens" [union-attr]
Check failure (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): vllm/worker/model_runner.py:1849: Item "None" of "Any | None" has no attribute "slot_mapping" [union-attr]

seq_start_idx += seq_length

def need_recv_kv(self, model_input, kv_caches) -> bool:
"""Check if we need to receive kv-cache from the other worker.
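To make the write path in write_global_cache easier to follow, here is the block-id arithmetic in isolation: slot_mapping stores one flat slot index per token, so the slot of a block's first token divided by the block size yields that block's physical block id. The numbers below are made up.

# Editor's illustrative sketch, not part of the commit.
# slot = block_id * block_size + offset_within_block, so
# block_id = slot // block_size.
import torch

block_size = 16
# One sequence of 32 tokens mapped to physical blocks 5 and 9 (made up).
slot_mapping = torch.tensor(
    [5 * block_size + i for i in range(block_size)] +
    [9 * block_size + i for i in range(block_size)])
seq_lens = [32]

seq_start_idx = 0
for seq_idx, seq_length in enumerate(seq_lens):
    num_blocks = seq_length // block_size
    for idx in range(num_blocks):
        block_id = slot_mapping[
            seq_start_idx + idx * block_size].item() // block_size
        print(seq_idx, idx, block_id)   # -> 0 0 5, then 0 1 9
    seq_start_idx += seq_length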
