
Commit

[torch.compile] Hide KV cache behind torch.compile boundary (vllm-project#11677)

Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Fred Reiss <[email protected]>
heheda12345 authored and frreiss committed Jan 10, 2025
1 parent 8d80490 commit 9b22ef0
Showing 18 changed files with 198 additions and 44 deletions.
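
In short: the unified_attention* custom ops no longer take kv_cache as an argument. Each Attention layer now holds a per-virtual-engine list of cache tensors, filled in by the new vllm.utils.bind_kv_cache helper and looked up through the forward context at run time, so (per the commit title) the cache tensor stays behind the torch.compile boundary. A minimal sketch of the new binding step, condensed from the tests added below (head counts and cache shapes are placeholders):

import torch

from vllm.attention import Attention
from vllm.utils import bind_kv_cache

# Stand-in for compilation_config.static_forward_context: layer name -> module.
ctx = {
    'layers.0.self_attn': Attention(32, 128, 0.1),
    'layers.1.self_attn': Attention(32, 128, 0.1),
}
# One virtual engine, one (placeholder-shaped) cache tensor per layer.
kv_cache = [torch.zeros((1, )), torch.zeros((1, ))]
bind_kv_cache(ctx, [kv_cache])

# Each layer's placeholder tensor has been replaced by the real cache.
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
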
18 changes: 12 additions & 6 deletions tests/kernels/test_encoder_decoder_attn.py
@@ -142,12 +142,18 @@ class that Attention will automatically select when it is constructed.
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
if test_pt.attn_type in (AttentionType.DECODER,
AttentionType.ENCODER_DECODER):
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
else:
kv_cache = torch.tensor([])

attn.kv_cache = [kv_cache]
return TestResources(scale, attn, kv_cache)


85 changes: 83 additions & 2 deletions tests/test_utils.py
@@ -7,9 +7,11 @@
import torch
from vllm_test_utils import monitor

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
StoreBoolean, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, supports_kw)
StoreBoolean, bind_kv_cache, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)

from .utils import error_on_warning, fork_new_process_for_each_test

@@ -325,6 +327,85 @@ def measure_current_non_torch():
lib.cudaFree(handle2)


def test_bind_kv_cache():
from vllm.attention import Attention

ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
]
bind_kv_cache(ctx, [kv_cache])
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]

def test_bind_kv_cache_non_attention():
from vllm.attention import Attention

# example from Jamba PP=2
ctx = {
'model.layers.20.attn': Attention(32, 128, 0.1),
'model.layers.28.attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
]
bind_kv_cache(ctx, [kv_cache])
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]


def test_bind_kv_cache_encoder_decoder():
from vllm.attention import Attention, AttentionType

# example from bart
ctx = {
'encoder.layers.0.self_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
'decoder.layers.0.encoder_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
'decoder.layers.0.self_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
}

kv_cache = [
torch.zeros((1, )),
]
encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache

bind_kv_cache(ctx, [kv_cache])
assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]


def test_bind_kv_cache_pp():
cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
with set_current_vllm_config(cfg):
from vllm.attention import Attention

ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
[torch.zeros((1, ))],
[torch.zeros((1, ))]
]
bind_kv_cache(ctx, kv_cache)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]


def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")

3 changes: 3 additions & 0 deletions tests/v1/engine/test_engine_core.py
@@ -4,6 +4,7 @@
import pytest
from transformers import AutoTokenizer

from tests.utils import fork_new_process_for_each_test
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
@@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
)


@fork_new_process_for_each_test
def test_engine_core(monkeypatch):

with monkeypatch.context() as m:
@@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
assert len(engine_core.scheduler.running) == 0


@fork_new_process_for_each_test
def test_engine_core_advanced_sampling(monkeypatch):
"""
A basic end-to-end test to verify that the engine functions correctly
3 changes: 3 additions & 0 deletions tests/v1/engine/test_engine_core_client.py
@@ -6,6 +6,7 @@
import pytest
from transformers import AutoTokenizer

from tests.utils import fork_new_process_for_each_test
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
@@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
break


@fork_new_process_for_each_test
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

@@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client.abort_requests([request.request_id])


@fork_new_process_for_each_test
@pytest.mark.asyncio
async def test_engine_core_client_asyncio(monkeypatch):

29 changes: 17 additions & 12 deletions vllm/attention/layer.py
@@ -121,6 +121,13 @@ def __init__(
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([]) for _ in range(get_current_vllm_config(
).parallel_config.pipeline_parallel_size)
]

def forward(
self,
@@ -148,11 +155,11 @@ def forward(
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, kv_cache, self.layer_name)
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, self.layer_name)
self.layer_name)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
@@ -230,12 +237,12 @@ def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._k_scale, self._v_scale)

@@ -244,7 +251,6 @@ def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
@@ -253,7 +259,7 @@ def unified_attention_fake(
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
@@ -264,12 +270,12 @@ def unified_attention_with_output(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(query,
key,
value,
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> None:
return
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
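
For readers unfamiliar with the pattern above, here is a standalone sketch of hiding mutable state behind a custom-op boundary, written with plain torch.library (PyTorch 2.4+) rather than vLLM's direct_register_custom_op helper; the op name, the module-level cache dict, and the shapes are all illustrative, and, as in the hunk above, the hidden cache update is not declared in mutates_args:

import torch

# Module-level registry standing in for the forward context: the op resolves
# its cache by layer name instead of receiving the tensor as a graph input.
_CACHES = {"layer0": torch.zeros(8)}


@torch.library.custom_op("demo::cached_add", mutates_args=())
def cached_add(x: torch.Tensor, layer_name: str) -> torch.Tensor:
    cache = _CACHES[layer_name]
    cache += x.sum()        # hidden state update, invisible to the graph
    return x + cache[0]


@cached_add.register_fake
def _(x: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing under torch.compile.
    return torch.empty_like(x)


@torch.compile
def fwd(x: torch.Tensor) -> torch.Tensor:
    # The compiled graph only sees (x, "layer0"); the cache stays hidden.
    return torch.ops.demo.cached_add(x, "layer0")


print(fwd(torch.ones(4)))
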
1 change: 0 additions & 1 deletion vllm/config.py
@@ -2780,7 +2780,6 @@ def model_post_init(self, __context: Any) -> None:
compilation_time: float = PrivateAttr

# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr

33 changes: 20 additions & 13 deletions vllm/forward_context.py
@@ -2,14 +2,17 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -21,9 +24,12 @@

@dataclass
class ForwardContext:
static_forward_context: Dict[str, Any]
# copy from vllm_config.compilation_config.static_forward_context
attn_layers: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context: Any
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass


_forward_context: Optional[ForwardContext] = None
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:


@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global forward_start_time
need_to_track_batchsize = track_batchsize and context is not None
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
attn_layers=vllm_config.compilation_config.static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata)
try:
yield
finally:
global batchsize_counter
global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize:
if hasattr(context, "num_prefill_tokens"):
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = context.num_prefill_tokens + \
context.num_decode_tokens
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = context.num_input_tokens
batchsize = attn_metadata.num_input_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
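
A minimal usage sketch of the updated context-manager signature, assuming vLLM at this commit; a real worker passes backend-specific attention metadata and the engine's config, while here None and a default VllmConfig are enough to show the plumbing:

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context

with set_forward_context(None, VllmConfig(), virtual_engine=0):
    ctx = get_forward_context()
    # attn_layers mirrors compilation_config.static_forward_context;
    # attn_metadata and virtual_engine are the per-pass dynamic fields.
    print(ctx.virtual_engine, ctx.attn_metadata, list(ctx.attn_layers))
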
35 changes: 35 additions & 0 deletions vllm/utils.py
@@ -2138,3 +2138,38 @@ def get_mp_context():
_check_multiproc_method()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)


def bind_kv_cache(
ctx: Dict[str, Any],
kv_cache: List[List[torch.Tensor]], # [virtual_engine][layer_index]
) -> None:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
# Special things handled here:
# 1. Some models have non-attention layers, e.g., Jamba
# 2. Pipeline parallelism, each rank only has a subset of layers
# 3. Encoder attention has no kv cache
# 4. Encoder-decoder models, encoder-decoder attention and decoder-only
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor
from vllm.attention import AttentionType
from vllm.model_executor.models.utils import extract_layer_index
layer_need_kv_cache = [
layer_name for layer_name in ctx
if ctx[layer_name].attn_type in (AttentionType.DECODER,
AttentionType.ENCODER_DECODER)
]
layer_index_sorted = sorted(
set(
extract_layer_index(layer_name)
for layer_name in layer_need_kv_cache))
for layer_name in layer_need_kv_cache:
kv_cache_idx = layer_index_sorted.index(
extract_layer_index(layer_name))
forward_ctx = ctx[layer_name]
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
assert forward_ctx.kv_cache[ve].numel() == 0
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
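
The index bookkeeping above boils down to "rank-local position by global layer index"; a standalone sketch of that mapping for a Jamba-like PP rank that only owns attention layers 20 and 28 (the layer names come from test_bind_kv_cache_non_attention above, and the local extract_layer_index stand-in is illustrative):

# No vLLM imports needed; a local stand-in for extract_layer_index.
layer_names = ['model.layers.20.attn', 'model.layers.28.attn']
extract_layer_index = lambda name: int(name.split('.')[2])

layer_index_sorted = sorted({extract_layer_index(n) for n in layer_names})
for name in layer_names:
    kv_cache_idx = layer_index_sorted.index(extract_layer_index(name))
    print(f"{name} -> kv_cache[ve][{kv_cache_idx}]")
# model.layers.20.attn -> kv_cache[ve][0]
# model.layers.28.attn -> kv_cache[ve][1]
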
