[Core] Support global prefix caching #11385

Open · wants to merge 4 commits into base: main
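This PR adds a `num_global_cache_blocks` option on top of automatic prefix caching (APC) so that KV blocks for shared prefixes can be reused beyond a single engine instance's local cache. A minimal usage sketch, mirroring the flags exercised in the example file below (the model and block count are illustrative values taken from that example, not requirements of the feature):

from vllm import LLM, SamplingParams

# Enable APC plus the global cache pool introduced by this PR.
llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          num_global_cache_blocks=5000)

outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0))
print(outputs[0].outputs[0].text)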
124 changes: 124 additions & 0 deletions examples/offline_inference_with_global_prefix.py
@@ -0,0 +1,124 @@
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

# NOTE: This is just a running example. For benchmarking purposes,
# please see benchmarks/benchmark_prefix_caching.py

# Common prefix.
prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for an 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on this information, fulfill "
    "the following paragraph: ")

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

generating_prompts = [prefix + prompt for prompt in prompts]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM without prefix caching as a baseline.
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)

print("Results without `enable_prefix_caching`")

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = regular_llm.generate(generating_prompts, sampling_params)

regular_generated_texts = []
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    regular_generated_texts.append(generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Destroy the LLM object and free up the GPU memory.
del regular_llm
cleanup_dist_env_and_memory()

# Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(model="facebook/opt-125m",
                        enable_prefix_caching=True,
                        num_global_cache_blocks=5000,
                        gpu_memory_utilization=0.4)

# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)

# Generate with prefix caching.
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)

print("Results with `enable_prefix_caching`")

cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    cached_generated_texts.append(generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Compare the results.
generated_same = all([
    regular_generated_texts[i] == cached_generated_texts[i]
    for i in range(len(prompts))
])
print(f"Generated answers are the same: {generated_same}")

"""
We can simulate the global prefix cache this way:
1. Run the vllm instance with APC for some time, so some prompts may not
hit APC as they are old and evicted
2. Delete the first vllm instance and start a new one. In this case,
global kv cache can be hit directly
Here we demo the second option.
"""
# Destroy the LLM object and free up the GPU memory.
del prefix_cached_llm
cleanup_dist_env_and_memory()

# Create an LLM with global prefix caching enabled.
global_prefix_cached_llm = LLM(model="facebook/opt-125m",
                               enable_prefix_caching=True,
                               num_global_cache_blocks=5000,
                               gpu_memory_utilization=0.4)

# Generate with global prefix caching.
outputs = global_prefix_cached_llm.generate(generating_prompts, sampling_params)

print("Results with `enable_global_prefix_caching`")

global_cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    global_cached_generated_texts.append(generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Compare the results.
generated_same = all([
    regular_generated_texts[i] == global_cached_generated_texts[i]
    for i in range(len(prompts))
])
print(f"Generated answers are the same: {generated_same}")
7 changes: 7 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -128,6 +128,7 @@ class FlashAttentionMetadata(AttentionMetadata):
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]
    block_hash_map: Optional[List[Dict[int, int]]]

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
@@ -234,6 +235,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            block_hash_map=self.block_hash_map,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
@@ -284,6 +286,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            block_hash_map=self.block_hash_map,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
@@ -376,6 +379,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.block_hash_map: List[Dict[int, int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
@@ -440,6 +444,8 @@ def _add_seq_group(
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)
            if seq_id in inter_data.block_hash_map:
                self.block_hash_map.append(inter_data.block_hash_map[seq_id])

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
@@ -559,6 +565,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            block_hash_map=self.block_hash_map,
            use_cuda_graph=use_captured_graph,
        )

5 changes: 5 additions & 0 deletions vllm/attention/backends/utils.py
@@ -126,6 +126,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.block_hash_map: List[Dict[int, int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
@@ -185,6 +186,8 @@ def _add_seq_group(
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)
            if seq_id in inter_data.block_hash_map:
                self.block_hash_map.append(inter_data.block_hash_map[seq_id])

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
@@ -275,6 +278,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            block_hash_map=self.block_hash_map,
            use_cuda_graph=use_captured_graph,
        )

@@ -326,6 +330,7 @@ def graph_capture_get_metadata_for_batch(
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            block_hash_map={},
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
2 changes: 2 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -223,6 +223,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            block_hash_map=self.block_hash_map,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
@@ -263,6 +264,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            block_hash_map=self.block_hash_map,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
3 changes: 2 additions & 1 deletion vllm/attention/ops/paged_attn.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch

@@ -28,6 +28,7 @@ class PagedAttentionMetadata:
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]
    block_hash_map: Optional[List[Dict[int, int]]]


class PagedAttention:
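Across the attention backends above, the new `block_hash_map` field carries, for each sequence in the batch, a mapping from a physical KV cache block id to that block's content hash, alongside the existing block table. A small illustration of the intended shape (the ids and hash values below are made up):

from typing import Dict, List

# One dict per sequence in the batch: {physical_block_id: content_hash}.
block_hash_map: List[Dict[int, int]] = [
    {12: 7553032747823366452, 13: -5286454714985730102},  # sequence 0
    {12: 7553032747823366452, 40: 991726271266146922},    # sequence 1 shares block 12
]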
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -886,6 +886,7 @@ def __init__(
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
        num_global_cache_blocks: int = 0,
        cpu_offload_gb: float = 0,
    ) -> None:
        self.block_size = block_size
@@ -896,6 +897,7 @@ def __init__(
        self.is_attention_free = is_attention_free
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self.num_global_cache_blocks = num_global_cache_blocks
        self.cpu_offload_gb = cpu_offload_gb

        self._verify_args()
10 changes: 9 additions & 1 deletion vllm/core/block/block_table.py
@@ -1,5 +1,5 @@
import math
from typing import List, Optional
from typing import Dict, List, Optional

from vllm.core.block.common import BlockList
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
@@ -257,6 +257,14 @@ def physical_block_ids(self) -> List[int]:
        """
        return self._blocks.ids()

    @property
    def block_hashes(self) -> Dict[int, int]:
        return self._blocks.hashes()

    @property
    def global_computed_list(self) -> List[int]:
        return self._blocks.global_computed_list()

    def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
        """Get the number of "unseen" tokens in the sequence.

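A sketch of how a caller might read the two new BlockTable properties. The helper below is illustrative and not part of this PR, and it assumes `global_computed_list` holds the ids of blocks already marked as computed in the global cache:

from typing import Dict, List

def summarize_block_table(block_table) -> None:
    """Inspect the new BlockTable properties added in this diff."""
    hashes: Dict[int, int] = block_table.block_hashes
    globally_computed: List[int] = block_table.global_computed_list

    # Blocks allocated for this sequence that are not (yet) flagged as
    # globally computed.
    pending = [bid for bid in block_table.physical_block_ids
               if bid not in globally_computed]
    print(f"{len(hashes)} hashed blocks, {len(pending)} not in the global cache")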
48 changes: 48 additions & 0 deletions vllm/core/block/common.py
@@ -235,25 +235,58 @@
    def __init__(self, blocks: List[Block]):
        self._blocks: List[Block] = []
        self._block_ids: List[int] = []
        self._block_hashes: Dict[int, int] = {}
        self._block_glb_computed: List[int] = []

        self.update(blocks)

    def _add_block_id(self, block_id: Optional[BlockId]) -> None:
        assert block_id is not None
        self._block_ids.append(block_id)

    def _add_block_hash(self,
                        block_id: Optional[BlockId],
                        block_hash: Optional[int]) -> None:
        assert block_id is not None
        self._block_hashes[block_id] = block_hash

Check failure on line 251 in vllm/core/block/common.py (GitHub Actions, mypy 3.9/3.10/3.11/3.12): Incompatible types in assignment (expression has type "Optional[int]", target has type "int") [assignment]

    def _add_block_glb_computed(self,
                                block_id: Optional[BlockId]) -> None:
        assert block_id is not None
        self._block_glb_computed.append(block_id)

    def _update_block_id(self, block_index: int,
                         new_block_id: Optional[BlockId]) -> None:
        assert new_block_id is not None
        self._block_ids[block_index] = new_block_id

    def _update_block_hash(self,
                           prev_block_id: Optional[BlockId],
                           new_block_id: Optional[BlockId],
                           new_block_hash: Optional[int]) -> None:
        assert new_block_id is not None
        del self._block_hashes[prev_block_id]

Check failure on line 268 in vllm/core/block/common.py (GitHub Actions, mypy 3.9/3.10/3.11/3.12): Argument 1 to "__delitem__" of "dict" has incompatible type "Optional[int]"; expected "int" [arg-type]
        self._block_hashes[new_block_id] = new_block_hash

Check failure on line 269 in vllm/core/block/common.py (GitHub Actions, mypy 3.9/3.10/3.11/3.12): Incompatible types in assignment (expression has type "Optional[int]", target has type "int") [assignment]

    def _update_block_glb_computed(self,
                                   prev_block_id: Optional[BlockId],
                                   new_block_id: Optional[BlockId]) -> None:
        assert new_block_id is not None
        for idx, block_id in enumerate(self._block_glb_computed):
            if block_id == prev_block_id:
                self._block_glb_computed[idx] = new_block_id
                break

    def update(self, blocks: List[Block]):
        self._blocks = blocks

        # Cache block ids for fast query
        self._block_ids = []
        for block in self._blocks:
            self._add_block_id(block.block_id)
            self._add_block_hash(block.block_id, block.content_hash)
            if block.global_computed:
                self._add_block_glb_computed(block.block_id)

    def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
        block = self._blocks[block_index]
@@ -264,10 +297,17 @@
        # CoW or promotion may update the internal block_id
        if prev_block_id != block.block_id:
            self._update_block_id(block_index, block.block_id)
            self._update_block_hash(
                prev_block_id, block.block_id, block.content_hash)
            if block.global_computed:
                self._update_block_glb_computed(prev_block_id, block.block_id)

    def append(self, new_block: Block):
        self._blocks.append(new_block)
        self._add_block_id(new_block.block_id)
        self._add_block_hash(new_block.block_id, new_block.content_hash)
        if new_block.global_computed:
            self._add_block_glb_computed(new_block.block_id)

    def __len__(self) -> int:
        return len(self._blocks)
@@ -282,13 +322,21 @@
    def reset(self):
        self._blocks = []
        self._block_ids = []
        self._block_hashes = []

Check failure on line 325 in vllm/core/block/common.py (GitHub Actions, mypy 3.9/3.10/3.11/3.12): Incompatible types in assignment (expression has type "list[Never]", variable has type "dict[int, int]") [assignment]
        self._block_glb_computed = []

    def list(self) -> List[Block]:
        return self._blocks

    def ids(self) -> List[int]:
        return self._block_ids

    def hashes(self) -> Dict[int, int]:
        return self._block_hashes

    def global_computed_list(self) -> List[int]:
        return self._block_glb_computed


@dataclass
class CacheMetricData:
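The mypy failures above share one root cause: `_block_hashes` is declared as `Dict[int, int]` even though the hashes passed in are `Optional[int]` (a block's content hash is None until the block is full), and `reset()` rebinds the dict to a list. A minimal sketch of one way the types could be reconciled; this is a suggestion, not part of the PR, and it only reproduces the hash-tracking pieces of `BlockList`:

from typing import Dict, List, Optional

BlockId = int

class BlockListHashTrackingSketch:
    def __init__(self) -> None:
        self._block_ids: List[BlockId] = []
        # content_hash is Optional: None until the block is full.
        self._block_hashes: Dict[BlockId, Optional[int]] = {}

    def _add_block_hash(self, block_id: Optional[BlockId],
                        block_hash: Optional[int]) -> None:
        assert block_id is not None
        self._block_hashes[block_id] = block_hash

    def _update_block_hash(self, prev_block_id: Optional[BlockId],
                           new_block_id: Optional[BlockId],
                           new_block_hash: Optional[int]) -> None:
        assert prev_block_id is not None and new_block_id is not None
        del self._block_hashes[prev_block_id]
        self._block_hashes[new_block_id] = new_block_hash

    def reset(self) -> None:
        self._block_ids = []
        self._block_hashes = {}  # keep the dict type instead of rebinding to a list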
8 changes: 8 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
100644 → 100755
@@ -416,6 +416,14 @@ def computed(self):
    def computed(self, value):
        self._proxy.computed = value

    @property
    def global_computed(self):
        return self._proxy.global_computed

    @global_computed.setter
    def global_computed(self, value):
        self._proxy.global_computed = value

    @property
    def last_accessed(self) -> float:
        return self._proxy.last_accessed