From 81d93c455d7e498d1743e628266f4d2e36b9e566 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 15:31:52 -0800 Subject: [PATCH 01/33] draft Signed-off-by: youkaichao --- .../decoder_only/language/test_jamba.py | 2 +- .../decoder_only/language/test_mamba.py | 2 +- .../test_encoder_decoder_model_runner.py | 2 +- tests/worker/test_model_runner.py | 3 +- vllm/config.py | 50 +++++++++---------- vllm/model_executor/models/jamba.py | 2 +- vllm/model_executor/models/mamba.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 9 +--- vllm/worker/enc_dec_model_runner.py | 2 +- vllm/worker/model_runner.py | 7 +-- 10 files changed, 38 insertions(+), 43 deletions(-) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index cae25ae9fa2c8..0797576d123e6 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -189,7 +189,7 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == VllmConfig.get_graph_batch_size( + while len(example_prompts) == VllmConfig.static_pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 35018c3c14dee..d5d70d38414ef 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -200,7 +200,7 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == VllmConfig.get_graph_batch_size( + while len(example_prompts) == VllmConfig.static_pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 5289c91f201cd..6afbd76e41bed 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. - graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size) + graph_batch_size = VllmConfig.static_pad_for_cudagraph(expanded_batch_size) cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_encoder_seq_lens = encoder_seq_lens + list( diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 4055524f3e0c7..33190f3a66e3f 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -177,7 +177,8 @@ def test_prepare_decode_cuda_graph(batch_size): model_input.attn_metadata, model_input.attn_metadata.slot_mapping) assert len(slot_mapping) == len(input_tokens) - expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list)) + expected_bs = VllmConfig.static_pad_for_cudagraph( + len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. 
device = model_runner.device assert attn_metadata.num_prefills == 0 diff --git a/vllm/config.py b/vllm/config.py index 38cf642b23cda..43151df962423 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2274,6 +2274,7 @@ def model_post_init(self, __context: Any) -> None: # not configurable, computed after init compile_sizes: List[int] = PrivateAttr capture_sizes: List[int] = PrivateAttr + bs_to_padded_graph_size: Dict[int, int] = PrivateAttr # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = PrivateAttr @@ -2371,6 +2372,12 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) + # pre-compute the mapping from batch size to padded graph size + self.bs_to_padded_graph_size = {} + for end, start in zip(self.capture_sizes, self.capture_sizes[1:]): + for bs in range(start, end): + self.bs_to_padded_graph_size[bs] = end + _BATCH_SIZE_ALIGNMENT = 8 # all the token sizes that **can** be captured by cudagraph. @@ -2378,11 +2385,15 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): # currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. # the actual sizes to capture will be determined by the model, # depending on the model's max_num_seqs. -# NOTE: get_graph_batch_size needs to be updated if this list is changed. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) ] +bs_to_padded_graph_size: Dict[int, int] = {} +for start, end in zip(_BATCH_SIZES_TO_CAPTURE, _BATCH_SIZES_TO_CAPTURE[1:]): + for bs in range(start, end): + bs_to_padded_graph_size[bs] = end + @dataclass class VllmConfig: @@ -2411,39 +2422,28 @@ class VllmConfig: init=True) # type: ignore instance_id: str = "" - @staticmethod - def get_graph_batch_size(batch_size: int) -> int: - """Returns the padded batch size given actual batch size. - - Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, - 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + def model_pad_for_cudagraph(self, batch_size: int) -> int: + """Returns the padded batch size given actual batch size, + considering the model's configuration. """ - if batch_size <= 2: - return batch_size - elif batch_size <= 4: - return 4 - else: - return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // - _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + if batch_size in self.compilation_config.bs_to_padded_graph_size: + return self.compilation_config.bs_to_padded_graph_size[batch_size] + return self.compilation_config.capture_sizes[0] @staticmethod - def get_max_graph_batch_size(max_num_seqs: int) -> int: + def static_pad_for_cudagraph(batch_size: int) -> int: """ - max_num_seqs: Maximum number of sequences in a batch. - _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. - - pad the max_num_seqs if necessary by calling get_graph_batch_size, - which will deal with some edge cases like 1, 2, 4. + This function statically pads the batch size to the nearest + number in _BATCH_SIZES_TO_CAPTURE , without considering the + model's configuration. if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size. if not, it means the padded size is larger than the largest size in _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE. 
""" - padded_size = VllmConfig.get_graph_batch_size(max_num_seqs) - if padded_size in _BATCH_SIZES_TO_CAPTURE: - return padded_size - assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] + if batch_size in bs_to_padded_graph_size: + return bs_to_padded_graph_size[batch_size] return _BATCH_SIZES_TO_CAPTURE[-1] @staticmethod @@ -2543,7 +2543,7 @@ def __post_init__(self): self.model_config is not None and \ not self.model_config.enforce_eager: max_batchsize_to_capture = \ - self.get_max_graph_batch_size( + self.static_pad_for_cudagraph( self.scheduler_config.max_num_seqs) batch_size_capture_list = [ size for size in _BATCH_SIZES_TO_CAPTURE diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5d5e8ae1ee532..2dc7027e01e7a 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -402,7 +402,7 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (VllmConfig.get_graph_batch_size( + max_batch_size = (VllmConfig.static_pad_for_cudagraph( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index b32032e411b0a..25564e11e6812 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -185,7 +185,7 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (VllmConfig.get_graph_batch_size( + max_batch_size = (VllmConfig.static_pad_for_cudagraph( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) self.mamba_cache = MambaCacheManager( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e8d964a722f60..37fbaf04d7be5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -440,7 +440,7 @@ def execute_model( and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Use piecewise CUDA graphs. # Add padding to the batch size. - num_input_tokens = self._get_padded_batch_size( + num_input_tokens = self.vllm_config.model_pad_for_cudagraph( num_scheduled_tokens) else: # Eager mode. @@ -603,13 +603,6 @@ def initialize_kv_cache(self, num_blocks: int) -> None: dtype=self.kv_cache_dtype, device=self.device)) - def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: - # TODO: Optimize this? - for size in self.cudagraph_batch_sizes: - if batch_size <= size: - return size - return None - @dataclass class CachedRequestState: diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 5697fbbaa2041..016deeb24eac7 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -464,7 +464,7 @@ def _prepare_encoder_model_input_tensors( # We will be using CUDA graph replay for this decode. 
max_len_of_block_table = self.get_max_block_per_batch() batch_size = len(encoder_seq_lens) - graph_batch_size = self.vllm_config.get_graph_batch_size( + graph_batch_size = self.vllm_config.model_pad_for_cudagraph( batch_size) assert graph_batch_size >= batch_size cuda_graph_pad_size = graph_batch_size - batch_size diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1bc5f65c7127f..6efd67e622bce 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -800,7 +800,8 @@ def _get_cuda_graph_pad_size(self, max_encoder_seq_len): return -1 - graph_batch_size = VllmConfig.get_graph_batch_size(batch_size) + graph_batch_size = self.runner.vllm_config.model_pad_for_cudagraph( + batch_size) assert graph_batch_size >= batch_size return graph_batch_size - batch_size @@ -1012,8 +1013,8 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size( - self.scheduler_config.max_num_seqs) + self.max_batchsize_to_capture = \ + self.vllm_config.compilation_config.capture_sizes[0] self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) From a69adbc021b3235bee36c19b1ab0657b164d5e1d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 15:51:05 -0800 Subject: [PATCH 02/33] fix cudagraph logic Signed-off-by: youkaichao --- vllm/config.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 43151df962423..0fca81282875e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2374,11 +2374,32 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): # pre-compute the mapping from batch size to padded graph size self.bs_to_padded_graph_size = {} - for end, start in zip(self.capture_sizes, self.capture_sizes[1:]): + for end, start in zip(self.capture_sizes, + self.capture_sizes[1:] + [0]): for bs in range(start, end): self.bs_to_padded_graph_size[bs] = end +""" +cudagraph batchsize padding logic: + +In the default case, `_BATCH_SIZES_TO_CAPTURE` is a list of all possible +batch sizes that cudagraph will capture. We pre-build a mapping from batch size +to padded graph size, so that we can quickly find the padded graph size for a +given batch size. Depending on the model's configuration, like `max_num_seqs`, +the candidate batch sizes to capture cudagraph will shrink to the subset of +`_BATCH_SIZES_TO_CAPTURE` that is less than or equal to `max_num_seqs`. + +However, if users specify the cudagraph capture sizes through compilation +config, we will use the specified sizes instead. + +In the end, `vllm_config.compilation_config.capture_sizes` will be the final +sizes to capture cudagraph (in descending order), and +`vllm_config.compilation_config.bs_to_padded_graph_size` will be the mapping +from batch size to padded graph size, if the batch size is less than or equal to +the largest size in `vllm_config.compilation_config.capture_sizes`. +""" + _BATCH_SIZE_ALIGNMENT = 8 # all the token sizes that **can** be captured by cudagraph. # they can be arbitrarily large. 
@@ -2390,7 +2411,8 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): ] bs_to_padded_graph_size: Dict[int, int] = {} -for start, end in zip(_BATCH_SIZES_TO_CAPTURE, _BATCH_SIZES_TO_CAPTURE[1:]): +for start, end in zip([0] + _BATCH_SIZES_TO_CAPTURE[:-1], + _BATCH_SIZES_TO_CAPTURE): for bs in range(start, end): bs_to_padded_graph_size[bs] = end From f2db1d0f3fcb9d23be8d6817470e711c258972db Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 15:58:55 -0800 Subject: [PATCH 03/33] add max_capture_size Signed-off-by: youkaichao --- vllm/config.py | 9 ++++++--- vllm/worker/model_runner.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 0fca81282875e..26e1287dc3ebe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2274,6 +2274,7 @@ def model_post_init(self, __context: Any) -> None: # not configurable, computed after init compile_sizes: List[int] = PrivateAttr capture_sizes: List[int] = PrivateAttr + max_capture_size: int = PrivateAttr bs_to_padded_graph_size: Dict[int, int] = PrivateAttr # keep track of enabled and disabled custom ops @@ -2379,6 +2380,8 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): for bs in range(start, end): self.bs_to_padded_graph_size[bs] = end + self.max_capture_size = self.capture_sizes[0] + """ cudagraph batchsize padding logic: @@ -2448,9 +2451,9 @@ def model_pad_for_cudagraph(self, batch_size: int) -> int: """Returns the padded batch size given actual batch size, considering the model's configuration. """ - if batch_size in self.compilation_config.bs_to_padded_graph_size: - return self.compilation_config.bs_to_padded_graph_size[batch_size] - return self.compilation_config.capture_sizes[0] + if batch_size > self.compilation_config.max_capture_size: + return self.compilation_config.max_capture_size + return self.compilation_config.bs_to_padded_graph_size[batch_size] @staticmethod def static_pad_for_cudagraph(batch_size: int) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6efd67e622bce..197456553e31f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1014,7 +1014,7 @@ def __init__( self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.max_batchsize_to_capture = \ - self.vllm_config.compilation_config.capture_sizes[0] + self.vllm_config.compilation_config.max_capture_size self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) From 08b6dd48fea75843faa33eb3f19b9d8841936d22 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:01:49 -0800 Subject: [PATCH 04/33] add _MAX_BATCH_SIZE_TO_CAPTURE Signed-off-by: youkaichao --- vllm/config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 26e1287dc3ebe..a556aa12eb702 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2419,6 +2419,8 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): for bs in range(start, end): bs_to_padded_graph_size[bs] = end +_MAX_BATCH_SIZE_TO_CAPTURE = _BATCH_SIZES_TO_CAPTURE[-1] + @dataclass class VllmConfig: @@ -2467,9 +2469,9 @@ def static_pad_for_cudagraph(batch_size: int) -> int: in _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE. 
""" - if batch_size in bs_to_padded_graph_size: - return bs_to_padded_graph_size[batch_size] - return _BATCH_SIZES_TO_CAPTURE[-1] + if batch_size > _MAX_BATCH_SIZE_TO_CAPTURE: + return _MAX_BATCH_SIZE_TO_CAPTURE + return bs_to_padded_graph_size[batch_size] @staticmethod def _get_quantization_config( From 3a1501ad9d02b1e1b2da13a6c3c5aa0d2a807cf6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:04:17 -0800 Subject: [PATCH 05/33] remove dead code Signed-off-by: youkaichao --- vllm/model_executor/models/jamba.py | 4 ++-- vllm/model_executor/models/mamba.py | 4 ++-- vllm/worker/xpu_model_runner.py | 4 ---- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 2dc7027e01e7a..fc873cdde8e10 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig +from vllm.config import _MAX_BATCH_SIZE_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -404,7 +404,7 @@ def forward(self, if self.mamba_cache is None: max_batch_size = (VllmConfig.static_pad_for_cudagraph( self.scheduler_config.max_num_seqs) if self.scheduler_config - else max(_BATCH_SIZES_TO_CAPTURE) + 2) + else _MAX_BATCH_SIZE_TO_CAPTURE + 2) layers_type = self.config.layers_block_type num_mamba_layers = sum( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 25564e11e6812..997fd0c9a594c 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -6,7 +6,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig +from vllm.config import _MAX_BATCH_SIZE_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -187,7 +187,7 @@ def forward(self, if self.mamba_cache is None: max_batch_size = (VllmConfig.static_pad_for_cudagraph( self.scheduler_config.max_num_seqs) if self.scheduler_config - else max(_BATCH_SIZES_TO_CAPTURE) + 2) + else _MAX_BATCH_SIZE_TO_CAPTURE + 2) self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, self.config.num_hidden_layers, max_batch_size, *self._get_mamba_cache_shape()) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index e6322e095bbb9..9cf25387560da 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -37,10 +37,6 @@ logger = init_logger(__name__) _PAD_SLOT_ID = -1 -_BATCH_SIZE_ALIGNMENT = 8 -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) -] TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") From 690700877f017267f628f701e77eb516fc0df19c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:07:01 -0800 Subject: [PATCH 06/33] fix Signed-off-by: youkaichao --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index a556aa12eb702..feb7b08617877 100644 
--- a/vllm/config.py +++ b/vllm/config.py @@ -2417,7 +2417,7 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): for start, end in zip([0] + _BATCH_SIZES_TO_CAPTURE[:-1], _BATCH_SIZES_TO_CAPTURE): for bs in range(start, end): - bs_to_padded_graph_size[bs] = end + bs_to_padded_graph_size[bs] = start _MAX_BATCH_SIZE_TO_CAPTURE = _BATCH_SIZES_TO_CAPTURE[-1] From 42c93005fafb0591a12d2583255da41e04032e99 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:09:09 -0800 Subject: [PATCH 07/33] fix Signed-off-by: youkaichao --- vllm/config.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index feb7b08617877..ea4383db14562 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2378,7 +2378,10 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): for end, start in zip(self.capture_sizes, self.capture_sizes[1:] + [0]): for bs in range(start, end): - self.bs_to_padded_graph_size[bs] = end + if bs == start: + bs_to_padded_graph_size[bs] = start + else: + bs_to_padded_graph_size[bs] = end self.max_capture_size = self.capture_sizes[0] @@ -2417,7 +2420,10 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): for start, end in zip([0] + _BATCH_SIZES_TO_CAPTURE[:-1], _BATCH_SIZES_TO_CAPTURE): for bs in range(start, end): - bs_to_padded_graph_size[bs] = start + if bs == start: + bs_to_padded_graph_size[bs] = start + else: + bs_to_padded_graph_size[bs] = end _MAX_BATCH_SIZE_TO_CAPTURE = _BATCH_SIZES_TO_CAPTURE[-1] From 4c8adcb7fe9e5af2d636d9cae35bfcaed9ba8c3a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:09:53 -0800 Subject: [PATCH 08/33] fix Signed-off-by: youkaichao --- vllm/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ea4383db14562..e5c7fd88e6a63 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2379,9 +2379,9 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): self.capture_sizes[1:] + [0]): for bs in range(start, end): if bs == start: - bs_to_padded_graph_size[bs] = start + self.bs_to_padded_graph_size[bs] = start else: - bs_to_padded_graph_size[bs] = end + self.bs_to_padded_graph_size[bs] = end self.max_capture_size = self.capture_sizes[0] From 69f8ff2bb426bf93c1a7582aaa5c3f60749d8902 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:12:31 -0800 Subject: [PATCH 09/33] fix Signed-off-by: youkaichao --- vllm/config.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index e5c7fd88e6a63..43d93a9a27c01 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2372,6 +2372,7 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) + self.max_capture_size = self.capture_sizes[0] # pre-compute the mapping from batch size to padded graph size self.bs_to_padded_graph_size = {} @@ -2382,8 +2383,8 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): self.bs_to_padded_graph_size[bs] = start else: self.bs_to_padded_graph_size[bs] = end - - self.max_capture_size = self.capture_sizes[0] + self.bs_to_padded_graph_size[ + self.max_capture_size] = self.max_capture_size """ @@ -2415,6 +2416,7 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) ] 
+_MAX_BATCH_SIZE_TO_CAPTURE = _BATCH_SIZES_TO_CAPTURE[-1] bs_to_padded_graph_size: Dict[int, int] = {} for start, end in zip([0] + _BATCH_SIZES_TO_CAPTURE[:-1], @@ -2424,8 +2426,8 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): bs_to_padded_graph_size[bs] = start else: bs_to_padded_graph_size[bs] = end - -_MAX_BATCH_SIZE_TO_CAPTURE = _BATCH_SIZES_TO_CAPTURE[-1] +bs_to_padded_graph_size[ + _MAX_BATCH_SIZE_TO_CAPTURE] = _MAX_BATCH_SIZE_TO_CAPTURE @dataclass From 830e34da1cc6efce4f2cd2e34e7ad333826b6df2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:18:04 -0800 Subject: [PATCH 10/33] hide some details in string form Signed-off-by: youkaichao --- vllm/config.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 43d93a9a27c01..9c7c4c14241dc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2287,6 +2287,17 @@ def model_post_init(self, __context: Any) -> None: # Map from layer name to the attention cls static_forward_context: Dict[str, Any] = PrivateAttr + def __str__(self) -> str: + exclude = [ + "static_forward_context", + "enabled_custom_ops", + "disabled_custom_ops", + "compilation_time", + "bs_to_padded_graph_size", + "pass_config", + ] + return self.json(exclude=exclude) + @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" From a8f3ef1013365806b4feca4edee6c33962a161c9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:19:14 -0800 Subject: [PATCH 11/33] hide some details in string form Signed-off-by: youkaichao --- vllm/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 9c7c4c14241dc..4177d683945f6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2287,7 +2287,7 @@ def model_post_init(self, __context: Any) -> None: # Map from layer name to the attention cls static_forward_context: Dict[str, Any] = PrivateAttr - def __str__(self) -> str: + def __repr__(self) -> str: exclude = [ "static_forward_context", "enabled_custom_ops", @@ -2298,6 +2298,8 @@ def __str__(self) -> str: ] return self.json(exclude=exclude) + __str__ = __repr__ + @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" From 8edddfdaaab86f0d5b1b7f4221d0f0535ade6ce9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:20:36 -0800 Subject: [PATCH 12/33] hide some details in string form Signed-off-by: youkaichao --- vllm/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4177d683945f6..d0500dd08b80f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2298,8 +2298,6 @@ def __repr__(self) -> str: ] return self.json(exclude=exclude) - __str__ = __repr__ - @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" From 2f7a17db9f3f0edd334089479021625a2facd189 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:21:52 -0800 Subject: [PATCH 13/33] hide some details in string form Signed-off-by: youkaichao --- vllm/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d0500dd08b80f..6a916e5f50623 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2288,16 +2288,18 @@ def model_post_init(self, __context: Any) -> None: static_forward_context: Dict[str, Any] = PrivateAttr def __repr__(self) -> str: - exclude = [ + 
exclude = { "static_forward_context", "enabled_custom_ops", "disabled_custom_ops", "compilation_time", "bs_to_padded_graph_size", "pass_config", - ] + } return self.json(exclude=exclude) + __str__ = __repr__ + @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" From 7379c67d2f4a594359f63ae6ccb9a5add4b4a02e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:25:11 -0800 Subject: [PATCH 14/33] fix pydantic Signed-off-by: youkaichao --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 6a916e5f50623..e538510055da5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2296,7 +2296,7 @@ def __repr__(self) -> str: "bs_to_padded_graph_size", "pass_config", } - return self.json(exclude=exclude) + return self.model_dump_json(exclude=exclude, exclude_unset=True) __str__ = __repr__ From 1c8067c44991d03684e7ea647bc0654349284c70 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 16:40:54 -0800 Subject: [PATCH 15/33] fix enforce eager Signed-off-by: youkaichao --- vllm/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index e538510055da5..52638673827b3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2385,7 +2385,8 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) - self.max_capture_size = self.capture_sizes[0] + self.max_capture_size = self.capture_sizes[ + 0] if self.capture_sizes else 0 # pre-compute the mapping from batch size to padded graph size self.bs_to_padded_graph_size = {} From ec2d4844138685cc700ccc0da39e3ee24c1a47c5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 14:23:14 -0800 Subject: [PATCH 16/33] fix merge Signed-off-by: youkaichao --- vllm/v1/worker/gpu_model_runner.py | 268 +---------------------------- 1 file changed, 1 insertion(+), 267 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 44047d4da365d..67cd1d84d63fa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple import numpy as np import torch @@ -641,269 +641,3 @@ def initialize_kv_cache(self, num_blocks: int) -> None: torch.zeros(kv_cache_shape, dtype=self.kv_cache_dtype, device=self.device)) - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = 
np.empty((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - - # Attention-related. - self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. 
- num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. 
- last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 From 224c6f240b04c24775c3855a1df77b86ad87d422 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 14:25:32 -0800 Subject: [PATCH 17/33] rename to pad_for_cudagraph Signed-off-by: youkaichao --- vllm/config.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/worker/enc_dec_model_runner.py | 2 +- vllm/worker/model_runner.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index eb949c99b0399..9886df456224d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2551,7 +2551,7 @@ class VllmConfig: init=True) # type: ignore instance_id: str = "" - def model_pad_for_cudagraph(self, batch_size: int) -> int: + def pad_for_cudagraph(self, batch_size: int) -> int: """Returns the padded batch size given actual batch size, considering the model's configuration. """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 67cd1d84d63fa..f24942068d1f8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -459,7 +459,7 @@ def execute_model( and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Use piecewise CUDA graphs. # Add padding to the batch size. - num_input_tokens = self.vllm_config.model_pad_for_cudagraph( + num_input_tokens = self.vllm_config.pad_for_cudagraph( num_scheduled_tokens) else: # Eager mode. diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 016deeb24eac7..bff01320d7927 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -464,7 +464,7 @@ def _prepare_encoder_model_input_tensors( # We will be using CUDA graph replay for this decode. 
max_len_of_block_table = self.get_max_block_per_batch() batch_size = len(encoder_seq_lens) - graph_batch_size = self.vllm_config.model_pad_for_cudagraph( + graph_batch_size = self.vllm_config.pad_for_cudagraph( batch_size) assert graph_batch_size >= batch_size cuda_graph_pad_size = graph_batch_size - batch_size diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 287768e0a1715..6ff98a8f1bab2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -802,7 +802,7 @@ def _get_cuda_graph_pad_size(self, max_encoder_seq_len): return -1 - graph_batch_size = self.runner.vllm_config.model_pad_for_cudagraph( + graph_batch_size = self.runner.vllm_config.pad_for_cudagraph( batch_size) assert graph_batch_size >= batch_size return graph_batch_size - batch_size From 957a72ef78872790086e3ead07c08cdfffaa4b4a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 14:28:36 -0800 Subject: [PATCH 18/33] remove mamba import Signed-off-by: youkaichao --- vllm/model_executor/models/jamba.py | 6 +++--- vllm/model_executor/models/mamba.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 9cebbdee617bd..a819e8e78e605 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import _MAX_BATCH_SIZE_TO_CAPTURE, CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE @@ -434,8 +434,8 @@ def forward(self, **kwargs): if self.mamba_cache is None: max_batch_size = (VllmConfig.static_pad_for_cudagraph( - self.scheduler_config.max_num_seqs) if self.scheduler_config - else _MAX_BATCH_SIZE_TO_CAPTURE + 2) + self.scheduler_config.max_num_seqs) + if self.scheduler_config else 8192 + 2) num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 6c22e8cf35190..e2ddeed015406 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -6,7 +6,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import _MAX_BATCH_SIZE_TO_CAPTURE, CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -209,8 +209,8 @@ def forward(self, **kwargs): if self.mamba_cache is None: max_batch_size = (VllmConfig.static_pad_for_cudagraph( - self.scheduler_config.max_num_seqs) if self.scheduler_config - else _MAX_BATCH_SIZE_TO_CAPTURE + 2) + self.scheduler_config.max_num_seqs) + if self.scheduler_config else 8192 + 2) num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) From d3c3bdc0176411028a52b8fc4e841c8d2928ca77 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 14:39:58 -0800 Subject: [PATCH 19/33] unify one function Signed-off-by: youkaichao --- .../decoder_only/language/test_jamba.py | 6 +- 
.../decoder_only/language/test_mamba.py | 6 +- .../test_encoder_decoder_model_runner.py | 4 +- tests/worker/test_model_runner.py | 3 +- vllm/config.py | 64 ++++++------------- vllm/model_executor/models/jamba.py | 10 +-- vllm/model_executor/models/mamba.py | 11 ++-- 7 files changed, 40 insertions(+), 64 deletions(-) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 0797576d123e6..cca5106764e5b 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,7 +1,7 @@ import pytest from tests.utils import multi_gpu_test -from vllm.config import VllmConfig +from vllm.config import VllmConfig, ModelConfig from vllm.sampling_params import SamplingParams from ...utils import check_outputs_equal @@ -189,7 +189,9 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == VllmConfig.static_pad_for_cudagraph( + vllm_config = VllmConfig( + model_config=ModelConfig(model=model, enforce_eager=False)) + while len(example_prompts) == vllm_config.pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index d5d70d38414ef..c71f950551a99 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -5,7 +5,7 @@ import pytest from transformers import AutoModelForCausalLM, AutoTokenizer -from vllm.config import VllmConfig +from vllm.config import VllmConfig, ModelConfig from vllm.sampling_params import SamplingParams from ...utils import check_outputs_equal @@ -200,7 +200,9 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == VllmConfig.static_pad_for_cudagraph( + vllm_config = VllmConfig( + model_config=ModelConfig(model=model, enforce_eager=False)) + while len(example_prompts) == vllm_config.pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 6afbd76e41bed..a6b3cb5759f2b 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -4,7 +4,6 @@ import pytest import torch -from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata @@ -548,7 +547,8 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. 
- graph_batch_size = VllmConfig.static_pad_for_cudagraph(expanded_batch_size) + graph_batch_size = model_runner.vllm_config.pad_for_cudagraph( + expanded_batch_size) cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_encoder_seq_lens = encoder_seq_lens + list( diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 33190f3a66e3f..aabe913c242e1 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -3,7 +3,6 @@ import pytest import torch -from vllm.config import VllmConfig from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import EngineArgs @@ -177,7 +176,7 @@ def test_prepare_decode_cuda_graph(batch_size): model_input.attn_metadata, model_input.attn_metadata.slot_mapping) assert len(slot_mapping) == len(input_tokens) - expected_bs = VllmConfig.static_pad_for_cudagraph( + expected_bs = model_runner.vllm_config.pad_for_cudagraph( len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device diff --git a/vllm/config.py b/vllm/config.py index 9886df456224d..75a26d928a18a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2501,28 +2501,6 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): the largest size in `vllm_config.compilation_config.capture_sizes`. """ -_BATCH_SIZE_ALIGNMENT = 8 -# all the token sizes that **can** be captured by cudagraph. -# they can be arbitrarily large. -# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. -# the actual sizes to capture will be determined by the model, -# depending on the model's max_num_seqs. -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) -] -_MAX_BATCH_SIZE_TO_CAPTURE = _BATCH_SIZES_TO_CAPTURE[-1] - -bs_to_padded_graph_size: Dict[int, int] = {} -for start, end in zip([0] + _BATCH_SIZES_TO_CAPTURE[:-1], - _BATCH_SIZES_TO_CAPTURE): - for bs in range(start, end): - if bs == start: - bs_to_padded_graph_size[bs] = start - else: - bs_to_padded_graph_size[bs] = end -bs_to_padded_graph_size[ - _MAX_BATCH_SIZE_TO_CAPTURE] = _MAX_BATCH_SIZE_TO_CAPTURE - @dataclass class VllmConfig: @@ -2559,22 +2537,6 @@ def pad_for_cudagraph(self, batch_size: int) -> int: return self.compilation_config.max_capture_size return self.compilation_config.bs_to_padded_graph_size[batch_size] - @staticmethod - def static_pad_for_cudagraph(batch_size: int) -> int: - """ - This function statically pads the batch size to the nearest - number in _BATCH_SIZES_TO_CAPTURE , without considering the - model's configuration. - - if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded - size. if not, it means the padded size is larger than the largest size - in _BATCH_SIZES_TO_CAPTURE, return the largest size in - _BATCH_SIZES_TO_CAPTURE. 
- """ - if batch_size > _MAX_BATCH_SIZE_TO_CAPTURE: - return _MAX_BATCH_SIZE_TO_CAPTURE - return bs_to_padded_graph_size[batch_size] - @staticmethod def _get_quantization_config( model_config: ModelConfig, @@ -2668,17 +2630,29 @@ def __post_init__(self): self.compilation_config.level = CompilationLevel.PIECEWISE if not envs.VLLM_USE_V1: + batch_size_capture_list = [] max_batchsize_to_capture = 0 if self.scheduler_config is not None and \ self.model_config is not None and \ not self.model_config.enforce_eager: - max_batchsize_to_capture = \ - self.static_pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - batch_size_capture_list = [ - size for size in _BATCH_SIZES_TO_CAPTURE - if size <= max_batchsize_to_capture - ] + # all the token sizes that **can** be captured by cudagraph. + # they can be arbitrarily large. + # currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. + # the actual sizes to capture will be determined by the model, + # depending on the model's max_num_seqs. + possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] + larger_sizes = [ + x for x in possible_sizes + if x >= self.scheduler_config.max_num_seqs + ] + if larger_sizes: + max_batchsize_to_capture = larger_sizes[0] + else: + max_batchsize_to_capture = possible_sizes[-1] + batch_size_capture_list = [ + size for size in possible_sizes + if size <= max_batchsize_to_capture + ] else: batch_size_capture_list = [] if self.model_config is not None and \ diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index a819e8e78e605..1ac6146968374 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -420,6 +420,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.max_batch_size = (vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + if self.scheduler_config else 8192 + 2) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -433,15 +436,12 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (VllmConfig.static_pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - if self.scheduler_config else 8192 + 2) num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, - *self._get_mamba_cache_shape()) + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) ( mamba_cache_tensors, state_indices_tensor, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index e2ddeed015406..a4ef5e1f68a32 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -195,6 +195,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) + self.max_batch_size = (vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + if self.scheduler_config else 8192 + 2) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) @@ -208,15 +211,11 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - 
max_batch_size = (VllmConfig.static_pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - if self.scheduler_config else 8192 + 2) - num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, - *self._get_mamba_cache_shape()) + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) ( mamba_cache_tensors, From f952d1891585008c6c4953ef9de02244317ac305 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 14:47:41 -0800 Subject: [PATCH 20/33] fix Signed-off-by: youkaichao --- .../decoder_only/language/test_jamba.py | 2 +- .../decoder_only/language/test_mamba.py | 2 +- vllm/config.py | 84 ++++++++++--------- 3 files changed, 46 insertions(+), 42 deletions(-) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index cca5106764e5b..bb3ac97f20b3d 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,7 +1,7 @@ import pytest from tests.utils import multi_gpu_test -from vllm.config import VllmConfig, ModelConfig +from vllm.config import ModelConfig, VllmConfig from vllm.sampling_params import SamplingParams from ...utils import check_outputs_equal diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index c71f950551a99..12d93cecc52ff 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -5,7 +5,7 @@ import pytest from transformers import AutoModelForCausalLM, AutoTokenizer -from vllm.config import VllmConfig, ModelConfig +from vllm.config import ModelConfig, VllmConfig from vllm.sampling_params import SamplingParams from ...utils import check_outputs_equal diff --git a/vllm/config.py b/vllm/config.py index 75a26d928a18a..8b10f4694b90d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2481,27 +2481,6 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): self.max_capture_size] = self.max_capture_size -""" -cudagraph batchsize padding logic: - -In the default case, `_BATCH_SIZES_TO_CAPTURE` is a list of all possible -batch sizes that cudagraph will capture. We pre-build a mapping from batch size -to padded graph size, so that we can quickly find the padded graph size for a -given batch size. Depending on the model's configuration, like `max_num_seqs`, -the candidate batch sizes to capture cudagraph will shrink to the subset of -`_BATCH_SIZES_TO_CAPTURE` that is less than or equal to `max_num_seqs`. - -However, if users specify the cudagraph capture sizes through compilation -config, we will use the specified sizes instead. - -In the end, `vllm_config.compilation_config.capture_sizes` will be the final -sizes to capture cudagraph (in descending order), and -`vllm_config.compilation_config.bs_to_padded_graph_size` will be the mapping -from batch size to padded graph size, if the batch size is less than or equal to -the largest size in `vllm_config.compilation_config.capture_sizes`. -""" - - @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. 
This @@ -2629,6 +2608,50 @@ def __post_init__(self): self.compilation_config.pass_config.enable_reshape = False self.compilation_config.level = CompilationLevel.PIECEWISE + self._set_cudagraph_sizes() + + if self.cache_config is not None and \ + self.cache_config.cpu_offload_gb > 0 and \ + self.compilation_config.level != CompilationLevel.NO_COMPILATION: + logger.warning( + "CPU offload is not supported with `torch.compile` yet." + " Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if self.lora_config is not None and self.compilation_config.level !=\ + CompilationLevel.NO_COMPILATION: + logger.warning("LoRA is not supported with `torch.compile` yet. " + "Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + current_platform.check_and_update_config(self) + + if not self.instance_id: + self.instance_id = random_uuid()[:5] + + def _set_cudagraph_sizes(self): + """ + cudagraph batchsize padding logic: + + In the default case, `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible + batch sizes that cudagraph will capture. Depending on the engine's configuration of `max_num_seqs`, + the candidate batch sizes to capture cudagraph will shrink to the subset which just cover the range of + `[1, max_num_seqs]`. + + However, if users specify the cudagraph capture sizes through compilation + config, we will use the specified sizes instead. + + In the end, `vllm_config.compilation_config.capture_sizes` will be the final + sizes to capture cudagraph (in descending order) + + We pre-build a mapping from batch size + to padded graph size, so that we can quickly find the padded graph size for a + given batch size. `vllm_config.compilation_config.bs_to_padded_graph_size` will be the mapping + from batch size to padded graph size, if the batch size is less than or equal to + the largest size in `vllm_config.compilation_config.capture_sizes`. + """ # noqa + + # calculate the default `batch_size_capture_list` if not envs.VLLM_USE_V1: batch_size_capture_list = [] max_batchsize_to_capture = 0 @@ -2663,25 +2686,6 @@ def __post_init__(self): self.compilation_config.init_with_cudagraph_sizes( batch_size_capture_list) - if self.cache_config is not None and \ - self.cache_config.cpu_offload_gb > 0 and \ - self.compilation_config.level != CompilationLevel.NO_COMPILATION: - logger.warning( - "CPU offload is not supported with `torch.compile` yet." - " Disabling `torch.compile`.") - self.compilation_config.level = CompilationLevel.NO_COMPILATION - - if self.lora_config is not None and self.compilation_config.level !=\ - CompilationLevel.NO_COMPILATION: - logger.warning("LoRA is not supported with `torch.compile` yet. 
" - "Disabling `torch.compile`.") - self.compilation_config.level = CompilationLevel.NO_COMPILATION - - current_platform.check_and_update_config(self) - - if not self.instance_id: - self.instance_id = random_uuid()[:5] - def __str__(self): return ( f"model={self.model_config.model!r}," From e92559cd7a5ffbe7fc596d8787203e3a155571b2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 14:53:03 -0800 Subject: [PATCH 21/33] fix Signed-off-by: youkaichao --- tests/models/decoder_only/language/test_jamba.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index bb3ac97f20b3d..057b04349e8b7 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,7 +1,7 @@ import pytest from tests.utils import multi_gpu_test -from vllm.config import ModelConfig, VllmConfig +from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams from ...utils import check_outputs_equal @@ -189,8 +189,7 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - vllm_config = VllmConfig( - model_config=ModelConfig(model=model, enforce_eager=False)) + vllm_config = EngineArgs(model=model).create_engine_config() while len(example_prompts) == vllm_config.pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) From a02b8f186c4860b5f877123214b21eeec4f1f77f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 14:58:14 -0800 Subject: [PATCH 22/33] fix mamba Signed-off-by: youkaichao --- tests/models/decoder_only/language/test_mamba.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 12d93cecc52ff..06739e8f02253 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -5,7 +5,7 @@ import pytest from transformers import AutoModelForCausalLM, AutoTokenizer -from vllm.config import ModelConfig, VllmConfig +from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams from ...utils import check_outputs_equal @@ -200,8 +200,7 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. 
If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - vllm_config = VllmConfig( - model_config=ModelConfig(model=model, enforce_eager=False)) + vllm_config = EngineArgs(model=model).create_engine_config() while len(example_prompts) == vllm_config.pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) From 24b548a0d71d5c5b9f18529680cb6b3e39d54bc9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 15:04:21 -0800 Subject: [PATCH 23/33] use list Signed-off-by: youkaichao --- vllm/config.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 8b10f4694b90d..03fb5195ee19e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2355,7 +2355,9 @@ def model_post_init(self, __context: Any) -> None: compile_sizes: List[int] = PrivateAttr capture_sizes: List[int] = PrivateAttr max_capture_size: int = PrivateAttr - bs_to_padded_graph_size: Dict[int, int] = PrivateAttr + # optimization: Dict[int, int] can be optimized to List[int] + # if we know all keys are in a range [0, max_capture_size] + bs_to_padded_graph_size: List[int] = PrivateAttr # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = PrivateAttr @@ -2469,7 +2471,9 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): 0] if self.capture_sizes else 0 # pre-compute the mapping from batch size to padded graph size - self.bs_to_padded_graph_size = {} + self.bs_to_padded_graph_size = [ + 0 for i in range(self.max_capture_size + 1) + ] for end, start in zip(self.capture_sizes, self.capture_sizes[1:] + [0]): for bs in range(start, end): From 4ba82c006abcbaabfc745e6d56e144e7bffb3c77 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 15:05:52 -0800 Subject: [PATCH 24/33] comment Signed-off-by: youkaichao --- vllm/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 03fb5195ee19e..904222f6f9a44 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2355,8 +2355,10 @@ def model_post_init(self, __context: Any) -> None: compile_sizes: List[int] = PrivateAttr capture_sizes: List[int] = PrivateAttr max_capture_size: int = PrivateAttr - # optimization: Dict[int, int] can be optimized to List[int] - # if we know all keys are in a range [0, max_capture_size] + # optimization: + # Intuitively, bs_to_padded_graph_size should be Dict[int, int]. + # since we know all keys are in a range [0, max_capture_size], + # we can optimize it to List[int] for better lookup performance. bs_to_padded_graph_size: List[int] = PrivateAttr # keep track of enabled and disabled custom ops From 859e3ee8264dfa28b2a01b19fb64bcf84667162b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 15:07:30 -0800 Subject: [PATCH 25/33] remove comments Signed-off-by: youkaichao --- vllm/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 904222f6f9a44..af2968390505f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2515,9 +2515,6 @@ class VllmConfig: instance_id: str = "" def pad_for_cudagraph(self, batch_size: int) -> int: - """Returns the padded batch size given actual batch size, - considering the model's configuration. 
- """ if batch_size > self.compilation_config.max_capture_size: return self.compilation_config.max_capture_size return self.compilation_config.bs_to_padded_graph_size[batch_size] From ce0cff997c22eaf65df879c8ccad77cc96fd5781 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 15:11:09 -0800 Subject: [PATCH 26/33] comments Signed-off-by: youkaichao --- vllm/config.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index af2968390505f..60d159efc0918 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2636,22 +2636,25 @@ def _set_cudagraph_sizes(self): """ cudagraph batchsize padding logic: - In the default case, `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible - batch sizes that cudagraph will capture. Depending on the engine's configuration of `max_num_seqs`, + `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible + batch sizes that cudagraph will capture. + + Depending on the engine's configuration of `max_num_seqs`, the candidate batch sizes to capture cudagraph will shrink to the subset which just cover the range of - `[1, max_num_seqs]`. - + `[1, max_num_seqs]`. In the common case, `max_num_seqs` is 256, + and the cudagraph batch sizes will be `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. + However, if users specify the cudagraph capture sizes through compilation config, we will use the specified sizes instead. In the end, `vllm_config.compilation_config.capture_sizes` will be the final sizes to capture cudagraph (in descending order) - We pre-build a mapping from batch size - to padded graph size, so that we can quickly find the padded graph size for a - given batch size. `vllm_config.compilation_config.bs_to_padded_graph_size` will be the mapping - from batch size to padded graph size, if the batch size is less than or equal to - the largest size in `vllm_config.compilation_config.capture_sizes`. + During runtime, if batchsize is larger than `vllm_config.compilation_config.capture_sizes`, + no cudagraph will be used. + if the batch size is no larger than `vllm_config.compilation_config.capture_sizes`, we can + quickly find the padded graph size for a + given batch size by looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. """ # noqa # calculate the default `batch_size_capture_list` From 3f750a2e48b235c9a577d08980cccfaad3f384ce Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 15:13:09 -0800 Subject: [PATCH 27/33] comments Signed-off-by: youkaichao --- vllm/config.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 60d159efc0918..81c39c321f76b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2639,23 +2639,25 @@ def _set_cudagraph_sizes(self): `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible batch sizes that cudagraph will capture. - Depending on the engine's configuration of `max_num_seqs`, - the candidate batch sizes to capture cudagraph will shrink to the subset which just cover the range of - `[1, max_num_seqs]`. In the common case, `max_num_seqs` is 256, - and the cudagraph batch sizes will be `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. + Depending on the engine's configuration of `max_num_seqs`, the candidate + batch sizes to capture cudagraph will shrink to the subset which just + cover the range of `[1, max_num_seqs]`. 
In the common case, `max_num_seqs` + is 256, and the cudagraph batch sizes will be `[1, 2, 4, 8, 16, 24, 32, 40, + ..., 256]`. However, if users specify the cudagraph capture sizes through compilation config, we will use the specified sizes instead. In the end, `vllm_config.compilation_config.capture_sizes` will be the final - sizes to capture cudagraph (in descending order) - - During runtime, if batchsize is larger than `vllm_config.compilation_config.capture_sizes`, - no cudagraph will be used. - if the batch size is no larger than `vllm_config.compilation_config.capture_sizes`, we can - quickly find the padded graph size for a - given batch size by looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. - """ # noqa + sizes to capture cudagraph (in descending order). + + During runtime, if batchsize is larger than + `vllm_config.compilation_config.capture_sizes`, no cudagraph will be used. + If the batch size is no larger than + `vllm_config.compilation_config.capture_sizes`, + we can quickly find the padded graph size for a given batch size by looking + up `vllm_config.compilation_config.bs_to_padded_graph_size`. + """ # calculate the default `batch_size_capture_list` if not envs.VLLM_USE_V1: From 56512ba8f5729319cd4f23e3fce24f19c75ed798 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 15:16:15 -0800 Subject: [PATCH 28/33] comments Signed-off-by: youkaichao --- vllm/config.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 81c39c321f76b..e77a50570aac7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2639,24 +2639,25 @@ def _set_cudagraph_sizes(self): `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible batch sizes that cudagraph will capture. - Depending on the engine's configuration of `max_num_seqs`, the candidate - batch sizes to capture cudagraph will shrink to the subset which just - cover the range of `[1, max_num_seqs]`. In the common case, `max_num_seqs` - is 256, and the cudagraph batch sizes will be `[1, 2, 4, 8, 16, 24, 32, 40, - ..., 256]`. + Depending on the engine's configuration of `max_num_seqs`, the + candidate batch sizes to capture cudagraph will shrink to the subset + which just cover the range of `[1, max_num_seqs]`. In the common case, + `max_num_seqs` is 256, and the cudagraph batch sizes will be + `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. - However, if users specify the cudagraph capture sizes through compilation - config, we will use the specified sizes instead. + However, if users specify the cudagraph capture sizes through + compilation config, we will use the specified sizes instead. - In the end, `vllm_config.compilation_config.capture_sizes` will be the final - sizes to capture cudagraph (in descending order). + In the end, `vllm_config.compilation_config.capture_sizes` will be the + final sizes to capture cudagraph (in descending order). During runtime, if batchsize is larger than - `vllm_config.compilation_config.capture_sizes`, no cudagraph will be used. + `vllm_config.compilation_config.capture_sizes`, + no cudagraph will be used. If the batch size is no larger than `vllm_config.compilation_config.capture_sizes`, - we can quickly find the padded graph size for a given batch size by looking - up `vllm_config.compilation_config.bs_to_padded_graph_size`. + we can quickly find the padded graph size for a given batch size by + looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. 
""" # calculate the default `batch_size_capture_list` From 5d6928a782a2b7a3acc937748bfc3fdc1824b3fb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 15:18:26 -0800 Subject: [PATCH 29/33] fix Signed-off-by: youkaichao --- vllm/config.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index e77a50570aac7..4c868b6df659a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2667,12 +2667,10 @@ def _set_cudagraph_sizes(self): if self.scheduler_config is not None and \ self.model_config is not None and \ not self.model_config.enforce_eager: - # all the token sizes that **can** be captured by cudagraph. - # they can be arbitrarily large. - # currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. - # the actual sizes to capture will be determined by the model, - # depending on the model's max_num_seqs. + possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] + # find the minimum size that is larger than max_num_seqs, + # which then becomes the max_batchsize_to_capture larger_sizes = [ x for x in possible_sizes if x >= self.scheduler_config.max_num_seqs @@ -2681,6 +2679,9 @@ def _set_cudagraph_sizes(self): max_batchsize_to_capture = larger_sizes[0] else: max_batchsize_to_capture = possible_sizes[-1] + + # filter out the sizes that are + # larger than max_batchsize_to_capture batch_size_capture_list = [ size for size in possible_sizes if size <= max_batchsize_to_capture From f4a7a77dabcbc132dd66e5c75fac0ed0244c0207 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 17:05:05 -0800 Subject: [PATCH 30/33] remove if Signed-off-by: youkaichao --- vllm/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4c868b6df659a..12ed80c366e43 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2515,8 +2515,10 @@ class VllmConfig: instance_id: str = "" def pad_for_cudagraph(self, batch_size: int) -> int: - if batch_size > self.compilation_config.max_capture_size: - return self.compilation_config.max_capture_size + # if batch_size > self.compilation_config.max_capture_size, + # it should raise an IndexError. 
+ # the caller should make sure the batch_size is within the range, + # i.e., batch_size <= self.compilation_config.max_capture_size return self.compilation_config.bs_to_padded_graph_size[batch_size] @staticmethod From 07d77a17dc4b5eaa532bd46dc4bd4e1aa94276f9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 20:37:09 -0800 Subject: [PATCH 31/33] fix jamba Signed-off-by: youkaichao --- vllm/model_executor/models/jamba.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 1ac6146968374..60733c3640533 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -420,9 +420,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self.max_batch_size = (vllm_config.pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - if self.scheduler_config else 8192 + 2) + if self.scheduler_config is not None: + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + else: + self.max_batch_size = 8192 + 2 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) From c45439157cef82d01656e72c9bdec01b36c98580 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 20:37:41 -0800 Subject: [PATCH 32/33] fix mamba Signed-off-by: youkaichao --- vllm/model_executor/models/mamba.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index a4ef5e1f68a32..62d8af6aeb2c1 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -195,9 +195,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) - self.max_batch_size = (vllm_config.pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - if self.scheduler_config else 8192 + 2) + if self.scheduler_config is not None: + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + else: + self.max_batch_size = 8192 + 2 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) From 74f69b66ffa0d8ef0b9de851f05060c8d9774f41 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Dec 2024 20:53:05 -0800 Subject: [PATCH 33/33] fix both Signed-off-by: youkaichao --- vllm/model_executor/models/jamba.py | 3 ++- vllm/model_executor/models/mamba.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 60733c3640533..831db2ae52d74 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -420,7 +420,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - if self.scheduler_config is not None: + if 
self.scheduler_config is not None and \ + not self.model_config.enforce_eager: if self.scheduler_config.max_num_seqs > \ vllm_config.compilation_config.max_capture_size: self.max_batch_size = \ diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 62d8af6aeb2c1..06c8d9723cd01 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -195,7 +195,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) - if self.scheduler_config is not None: + if self.scheduler_config is not None and \ + not self.model_config.enforce_eager: if self.scheduler_config.max_num_seqs > \ vllm_config.compilation_config.max_capture_size: self.max_batch_size = \
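As a reading aid for the series above, here is a minimal standalone sketch of the cudagraph padding table that the patches build inside `CompilationConfig.init_with_cudagraph_sizes` and consult in `VllmConfig.pad_for_cudagraph`. The helper names, standalone signatures, and the 256-sequence capture list below are illustrative only; in the patches the table lives on the compilation config as a flat list (PATCH 23/33) so the per-step lookup is a single index operation, and batch sizes above the largest capture size are left to raise `IndexError`, with the caller responsible for staying in range (PATCH 30/33).

    from typing import List

    def build_bs_to_padded_graph_size(capture_sizes: List[int]) -> List[int]:
        # Illustrative stand-in for the table built in
        # CompilationConfig.init_with_cudagraph_sizes: a flat list where
        # result[bs] is the smallest capture size that is >= bs.
        max_capture_size = max(capture_sizes)
        return [min(s for s in capture_sizes if s >= bs)
                for bs in range(max_capture_size + 1)]

    def pad_for_cudagraph(batch_size: int, table: List[int]) -> int:
        # Plain list indexing, mirroring VllmConfig.pad_for_cudagraph after
        # PATCH 30/33: a batch size above the largest capture size raises
        # IndexError, and the caller must avoid that case.
        return table[batch_size]

    if __name__ == "__main__":
        # Typical default capture list for max_num_seqs = 256, per the
        # _set_cudagraph_sizes docstring above: [1, 2, 4, 8, 16, 24, ..., 256].
        capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 33)]
        table = build_bs_to_padded_graph_size(capture_sizes)
        assert pad_for_cudagraph(3, table) == 4
        assert pad_for_cudagraph(17, table) == 24
        assert pad_for_cudagraph(256, table) == 256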