[core] clean up cudagraph batchsize padding logic #10996

Merged 34 commits on Dec 13, 2024

Commits
81d93c4  draft (youkaichao, Dec 8, 2024)
a69adbc  fix cudagraph logic (youkaichao, Dec 8, 2024)
f2db1d0  add max_capture_size (youkaichao, Dec 8, 2024)
08b6dd4  add _MAX_BATCH_SIZE_TO_CAPTURE (youkaichao, Dec 9, 2024)
3a1501a  remove dead code (youkaichao, Dec 9, 2024)
6907008  fix (youkaichao, Dec 9, 2024)
42c9300  fix (youkaichao, Dec 9, 2024)
4c8adcb  fix (youkaichao, Dec 9, 2024)
69f8ff2  fix (youkaichao, Dec 9, 2024)
830e34d  hide some details in string form (youkaichao, Dec 9, 2024)
a8f3ef1  hide some details in string form (youkaichao, Dec 9, 2024)
8edddfd  hide some details in string form (youkaichao, Dec 9, 2024)
2f7a17d  hide some details in string form (youkaichao, Dec 9, 2024)
7379c67  fix pydantic (youkaichao, Dec 9, 2024)
1c8067c  fix enforce eager (youkaichao, Dec 9, 2024)
59ea38f  Merge branch 'main' into cudagraph_sizes (youkaichao, Dec 12, 2024)
ec2d484  fix merge (youkaichao, Dec 12, 2024)
224c6f2  rename to pad_for_cudagraph (youkaichao, Dec 12, 2024)
957a72e  remove mamba import (youkaichao, Dec 12, 2024)
d3c3bdc  unify one function (youkaichao, Dec 12, 2024)
f952d18  fix (youkaichao, Dec 12, 2024)
e92559c  fix (youkaichao, Dec 12, 2024)
a02b8f1  fix mamba (youkaichao, Dec 12, 2024)
24b548a  use list (youkaichao, Dec 12, 2024)
4ba82c0  comment (youkaichao, Dec 12, 2024)
859e3ee  remove comments (youkaichao, Dec 12, 2024)
ce0cff9  comments (youkaichao, Dec 12, 2024)
3f750a2  comments (youkaichao, Dec 12, 2024)
56512ba  comments (youkaichao, Dec 12, 2024)
5d6928a  fix (youkaichao, Dec 12, 2024)
f4a7a77  remove if (youkaichao, Dec 13, 2024)
07d77a1  fix jamba (youkaichao, Dec 13, 2024)
c454391  fix mamba (youkaichao, Dec 13, 2024)
74f69b6  fix both (youkaichao, Dec 13, 2024)
5 changes: 3 additions & 2 deletions tests/models/decoder_only/language/test_jamba.py
@@ -1,7 +1,7 @@
import pytest

from tests.utils import multi_gpu_test
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == VllmConfig.get_graph_batch_size(
vllm_config = EngineArgs(model=model).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])

5 changes: 3 additions & 2 deletions tests/models/decoder_only/language/test_mamba.py
@@ -5,7 +5,7 @@
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == VllmConfig.get_graph_batch_size(
vllm_config = EngineArgs(model=model).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])

4 changes: 2 additions & 2 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -4,7 +4,6 @@
import pytest
import torch

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -548,7 +547,8 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
graph_batch_size = model_runner.vllm_config.pad_for_cudagraph(
expanded_batch_size)
cuda_graph_pad_size = graph_batch_size - expanded_batch_size
padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
padded_encoder_seq_lens = encoder_seq_lens + list(
4 changes: 2 additions & 2 deletions tests/worker/test_model_runner.py
@@ -3,7 +3,6 @@
import pytest
import torch

from vllm.config import VllmConfig
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
@@ -177,7 +176,8 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
assert len(slot_mapping) == len(input_tokens)

expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
expected_bs = model_runner.vllm_config.pad_for_cudagraph(
len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.num_prefills == 0
169 changes: 103 additions & 66 deletions vllm/config.py
@@ -2354,6 +2354,12 @@ def model_post_init(self, __context: Any) -> None:
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
max_capture_size: int = PrivateAttr
# optimization:
# Intuitively, bs_to_padded_graph_size should be Dict[int, int].
# since we know all keys are in a range [0, max_capture_size],
# we can optimize it to List[int] for better lookup performance.
bs_to_padded_graph_size: List[int] = PrivateAttr

# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
@@ -2365,6 +2371,19 @@ def model_post_init(self, __context: Any) -> None:
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr

def __repr__(self) -> str:
exclude = {
"static_forward_context",
"enabled_custom_ops",
"disabled_custom_ops",
"compilation_time",
"bs_to_padded_graph_size",
"pass_config",
}
return self.model_dump_json(exclude=exclude, exclude_unset=True)

__str__ = __repr__

@classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
@@ -2450,18 +2469,22 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):

# sort to make sure cudagraph capture sizes are in descending order
self.capture_sizes.sort(reverse=True)
self.max_capture_size = self.capture_sizes[
0] if self.capture_sizes else 0


_BATCH_SIZE_ALIGNMENT = 8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]
# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_capture_size + 1)
]
for end, start in zip(self.capture_sizes,
self.capture_sizes[1:] + [0]):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end
self.bs_to_padded_graph_size[
self.max_capture_size] = self.max_capture_size
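
For readers of this hunk, a standalone sketch (not vLLM code) of what the pre-computation above produces, using a made-up four-entry capture list and cross-checking the resulting table against a naive linear search:

# Standalone sketch: same construction as the hunk above, on toy sizes.
capture_sizes = [8, 4, 2, 1]            # descending, as vLLM stores them
max_capture_size = capture_sizes[0]

# bs_to_padded[bs] should end up as the smallest capture size >= bs
bs_to_padded = [0] * (max_capture_size + 1)
for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
    for bs in range(start, end):
        # exact capture sizes map to themselves; everything else rounds up
        bs_to_padded[bs] = start if bs == start else end
bs_to_padded[max_capture_size] = max_capture_size

print(bs_to_padded)                     # [0, 1, 2, 4, 4, 8, 8, 8, 8]
for bs in range(1, max_capture_size + 1):
    assert bs_to_padded[bs] == min(s for s in capture_sizes if s >= bs)

Because every key falls in [0, max_capture_size], the table can be a plain list indexed by batch size instead of a dict, which is the optimization described by the comment earlier in this hunk.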


@dataclass
@@ -2491,40 +2514,10 @@ class VllmConfig:
init=True) # type: ignore
instance_id: str = ""

@staticmethod
def get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.

Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

@staticmethod
def get_max_graph_batch_size(max_num_seqs: int) -> int:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.

pad the max_num_seqs if necessary by calling get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.

if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
size. if not, it means the padded size is larger than the largest size
in _BATCH_SIZES_TO_CAPTURE, return the largest size in
_BATCH_SIZES_TO_CAPTURE.
"""
padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
if padded_size in _BATCH_SIZES_TO_CAPTURE:
return padded_size
assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
return _BATCH_SIZES_TO_CAPTURE[-1]
def pad_for_cudagraph(self, batch_size: int) -> int:
if batch_size > self.compilation_config.max_capture_size:
return self.compilation_config.max_capture_size
return self.compilation_config.bs_to_padded_graph_size[batch_size]
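
A usage sketch for the new method, mirroring the updated tests earlier in this diff (the model name is illustrative, and the concrete padded value assumes the default 8-aligned capture sizes with cudagraphs enabled):

from vllm.engine.arg_utils import EngineArgs

# build an engine config the same way the updated tests do
vllm_config = EngineArgs(model="state-spaces/mamba-130m-hf").create_engine_config()

padded = vllm_config.pad_for_cudagraph(50)
# with capture sizes 1, 2, 4, 8, 16, 24, ... this would be 56
assert padded >= 50
assert padded <= vllm_config.compilation_config.max_capture_size

Note that for a batch size above max_capture_size the method returns max_capture_size itself rather than a larger value; callers are expected to check against the largest capture size first and fall back to eager mode otherwise, as the gpu_model_runner.py change further down does.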

@staticmethod
def _get_quantization_config(
@@ -2618,27 +2611,7 @@ def __post_init__(self):
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.level = CompilationLevel.PIECEWISE

if not envs.VLLM_USE_V1:
max_batchsize_to_capture = 0
if self.scheduler_config is not None and \
self.model_config is not None and \
not self.model_config.enforce_eager:
max_batchsize_to_capture = \
self.get_max_graph_batch_size(
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [
size for size in _BATCH_SIZES_TO_CAPTURE
if size <= max_batchsize_to_capture
]
else:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]

self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
self._set_cudagraph_sizes()

if self.cache_config is not None and \
self.cache_config.cpu_offload_gb > 0 and \
@@ -2659,6 +2632,70 @@ def __post_init__(self):
if not self.instance_id:
self.instance_id = random_uuid()[:5]

def _set_cudagraph_sizes(self):
"""
cudagraph batchsize padding logic:

`[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible
batch sizes that cudagraph will capture.

Depending on the engine's configuration of `max_num_seqs`, the
candidate batch sizes to capture cudagraph will shrink to the subset
which just cover the range of `[1, max_num_seqs]`. In the common case,
`max_num_seqs` is 256, and the cudagraph batch sizes will be
`[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`.

However, if users specify the cudagraph capture sizes through
compilation config, we will use the specified sizes instead.

In the end, `vllm_config.compilation_config.capture_sizes` will be the
final sizes to capture cudagraph (in descending order).

During runtime, if batchsize is larger than
`vllm_config.compilation_config.capture_sizes`,
no cudagraph will be used.
If the batch size is no larger than
`vllm_config.compilation_config.capture_sizes`,
we can quickly find the padded graph size for a given batch size by
looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
"""

# calculate the default `batch_size_capture_list`
if not envs.VLLM_USE_V1:
batch_size_capture_list = []
max_batchsize_to_capture = 0
if self.scheduler_config is not None and \
self.model_config is not None and \
not self.model_config.enforce_eager:

possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
# find the minimum size that is larger than max_num_seqs,
# which then becomes the max_batchsize_to_capture
larger_sizes = [
x for x in possible_sizes
if x >= self.scheduler_config.max_num_seqs
]
if larger_sizes:
max_batchsize_to_capture = larger_sizes[0]
else:
max_batchsize_to_capture = possible_sizes[-1]

# filter out the sizes that are
# larger than max_batchsize_to_capture
batch_size_capture_list = [
size for size in possible_sizes
if size <= max_batchsize_to_capture
]
else:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]

self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)

def __str__(self):
return (
f"model={self.model_config.model!r},"
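A condensed sketch (not the code itself) of the V0 selection performed by _set_cudagraph_sizes above, omitting the enforce_eager and V1 branches:

def capture_sizes_for(max_num_seqs: int) -> list:
    # every size cudagraph could capture: 1, 2, 4, 8, 16, ..., 8192
    possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
    # the smallest candidate covering max_num_seqs becomes the cap
    larger_sizes = [x for x in possible_sizes if x >= max_num_seqs]
    cap = larger_sizes[0] if larger_sizes else possible_sizes[-1]
    return [size for size in possible_sizes if size <= cap]

print(capture_sizes_for(256)[-3:])   # [240, 248, 256]
print(capture_sizes_for(5))          # [1, 2, 4, 8]

init_with_cudagraph_sizes then sorts the chosen sizes in descending order and builds the bs_to_padded_graph_size table shown earlier in this diff.
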
12 changes: 6 additions & 6 deletions vllm/model_executor/models/jamba.py
@@ -7,7 +7,7 @@

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -420,6 +420,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.max_batch_size = (vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
if self.scheduler_config else 8192 + 2)
Review thread on the line above:

youkaichao (Member, Author): @tlrmchlsmth do you know when self.scheduler_config will be None here?

tlrmchlsmth (Collaborator): No, I don't. @mzusman do you know?

mzusman (Contributor): Shouldn't be None, just added it as a safety guard - it was added before the whole vllm_config was available inside the modeling file.

youkaichao (Member, Author): thanks! can you try to remove it then?

mzusman (Contributor): yeah, will open a PR shortly

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@@ -433,15 +436,12 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape())
(
mamba_cache_tensors,
state_indices_tensor,
13 changes: 6 additions & 7 deletions vllm/model_executor/models/mamba.py
@@ -6,7 +6,7 @@
from transformers import MambaConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -195,6 +195,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors)
self.max_batch_size = (vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
if self.scheduler_config else 8192 + 2)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids)
@@ -208,15 +211,11 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape())

(
mamba_cache_tensors,
11 changes: 2 additions & 9 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1,6 +1,6 @@
import gc
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple

import numpy as np
import torch
@@ -459,7 +459,7 @@ def execute_model(
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self._get_padded_batch_size(
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
@@ -641,10 +641,3 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))

def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
# TODO: Optimize this?
for size in self.cudagraph_batch_sizes:
if batch_size <= size:
return size
return None
2 changes: 1 addition & 1 deletion vllm/worker/enc_dec_model_runner.py
@@ -464,7 +464,7 @@ def _prepare_encoder_model_input_tensors(
# We will be using CUDA graph replay for this decode.
max_len_of_block_table = self.get_max_block_per_batch()
batch_size = len(encoder_seq_lens)
graph_batch_size = self.vllm_config.get_graph_batch_size(
graph_batch_size = self.vllm_config.pad_for_cudagraph(
batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size