diff --git a/vllm/config.py b/vllm/config.py index b3329f1c449ff..27ad17d437a49 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -928,6 +928,9 @@ class SchedulerConfig: a single iteration. max_num_seqs: Maximum number of sequences to be processed in a single iteration. + max_num_prefill_seqs: Maximum number of prefill sequences to be + processed in a single iteration. Used only with padding-aware + scheduling. max_model_len: Maximum length of a sequence (including prompt and generated text). use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. @@ -951,11 +954,14 @@ class SchedulerConfig: when SPMD worker architecture is enabled. I.e., VLLM_USE_RAY_SPMD_WORKER=1 policy: The scheduling policy to use. "fcfs" (default) or "priority". + use_padding_aware_scheduling: If True, scheduler will consider padded + tokens in prefill. """ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, + max_num_prefill_seqs: Optional[int], max_model_len: int, use_v2_block_manager: bool = True, num_lookahead_slots: int = 0, @@ -967,7 +973,8 @@ def __init__(self, num_scheduler_steps: int = 1, multi_step_stream_outputs: bool = False, send_delta_data: bool = False, - policy: str = "fcfs") -> None: + policy: str = "fcfs", + use_padding_aware_scheduling=False) -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: if num_scheduler_steps > 1: @@ -1006,6 +1013,7 @@ def __init__(self, self.max_num_batched_tokens) self.max_num_seqs = max_num_seqs + self.max_num_prefill_seqs = max_num_prefill_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots @@ -1017,6 +1025,7 @@ def __init__(self, self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data self.policy = policy + self.use_padding_aware_scheduling = use_padding_aware_scheduling self._verify_args() def _verify_args(self) -> None: @@ -1047,6 +1056,13 @@ def _verify_args(self) -> None: "num_scheduler_steps " f"({self.num_scheduler_steps}) must be greater than or " "equal to 1.") + if self.max_num_prefill_seqs is not None \ + and not self.use_padding_aware_scheduling: + raise ValueError("max_num_prefill_seqs can be only " + "used with padding-aware-scheduling. 
") + if self.use_padding_aware_scheduling and self.chunked_prefill_enabled: + raise ValueError("Padding-aware scheduling currently " + "does not work with chunked prefill ") @property def is_multi_step(self) -> bool: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f3a5016d0e62a..c62b905766cad 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, @@ -101,6 +102,94 @@ def num_curr_seqs(self): return self._num_curr_seqs +@dataclass +class PaddingAwareSchedulingBudget(SchedulingBudget): + max_num_prefill_seqs: Optional[int] = None + _prefill_request_ids_max_seq_lens: Dict[str, + int] = field(default_factory=dict) + _max_seq_len: int = 0 + _num_curr_prefill_seqs: int = 0 + + def _generic_padding_fn(self, batch_size, max_seq_len) -> int: + return batch_size * max_seq_len + + def _hpu_padding_fn(self, batch_size, max_seq_len): + from vllm.worker.hpu_model_runner import (HPUBucketingGlobalState, + find_bucket) + padded_bs = batch_size + padded_seq = max_seq_len + + hpu_bucketing_global_state = HPUBucketingGlobalState() + + bs_cfg = hpu_bucketing_global_state.prompt_bs_bucket_cfg + if bs_cfg is not None: + padded_bs = find_bucket(batch_size, bs_cfg) + else: + logger.warning( + "prompt_bs_bucket_cfg was not set! Using unpadded batch size.") + seq_cfg = hpu_bucketing_global_state.prompt_seq_bucket_cfg + if seq_cfg is not None: + padded_seq = find_bucket(max_seq_len, seq_cfg) + else: + logger.warning("prompt_seq_bucket_cfg was not set! 
" + "Using unpadded sequence length.") + return padded_bs * padded_seq + + def _padding_fn_selector(self): + if current_platform.is_hpu(): + return self._hpu_padding_fn + return self._generic_padding_fn + + def _maybe_update_max_seq_len(self, + new_seq_max_seq_len: Optional[int] = None): + if new_seq_max_seq_len is not None \ + and new_seq_max_seq_len > self._max_seq_len: + self._max_seq_len = new_seq_max_seq_len + return + self._max_seq_len = max( + self._prefill_request_ids_max_seq_lens.values()) + + def add_prefill_seqs(self, req_id, num_curr_prefill_seqs, max_seq_len): + self._prefill_request_ids_max_seq_lens[req_id] = max_seq_len + self._num_curr_prefill_seqs += num_curr_prefill_seqs + self._maybe_update_max_seq_len(max_seq_len) + + def subtract_prefill_seqs(self, req_id, num_curr_prefill_seqs): + if req_id in self._prefill_request_ids_max_seq_lens: + popped_seq_len = self._prefill_request_ids_max_seq_lens.pop(req_id) + self._num_curr_prefill_seqs -= num_curr_prefill_seqs + if popped_seq_len == self._max_seq_len: + self._maybe_update_max_seq_len() + + def can_schedule(self, + *args, + num_new_tokens: int, + num_new_seqs: int, + is_prefill: bool = False, + max_seq_len: int = 0): + can_parent_schedule = super().can_schedule( + *args, num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs) + if not can_parent_schedule or not is_prefill: + return can_parent_schedule + new_batch_size = self._num_curr_prefill_seqs + num_new_seqs + new_max_seq_len = max(max(self._max_seq_len, max_seq_len), 1) + padding_fn = self._padding_fn_selector() + num_new_padded_tokens = padding_fn(new_batch_size, new_max_seq_len) + result = num_new_padded_tokens <= self.token_budget + if self.max_num_prefill_seqs is not None and result: + result = self._num_curr_prefill_seqs + num_new_seqs \ + <= self.max_num_prefill_seqs + return result + + @property + def max_seq_len(self): + return self._max_seq_len + + @property + def num_curr_prefill_seqs(self): + return self._num_curr_prefill_seqs + + @dataclass class ScheduledSequenceGroup: # A sequence group that's scheduled. @@ -937,9 +1026,18 @@ def _schedule_prefills( continue num_new_seqs = seq_group.get_max_num_running_seqs() + max_prefill_seq_len = None + can_schedule_kwargs = { + 'num_new_tokens': num_new_tokens, + 'num_new_seqs': num_new_seqs + } + if self.scheduler_config.use_padding_aware_scheduling: + max_prefill_seq_len = max( + [seq.get_num_new_tokens() for seq in seq_group.get_seqs()]) + can_schedule_kwargs['is_prefill'] = True + can_schedule_kwargs['max_seq_len'] = max_prefill_seq_len if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + or not budget.can_schedule(**can_schedule_kwargs)): break # Can schedule this request. @@ -970,6 +1068,10 @@ def _schedule_prefills( token_chunk_size=num_new_tokens)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) budget.add_num_seqs(seq_group.request_id, num_new_seqs) + if self.scheduler_config.use_padding_aware_scheduling: + assert isinstance(budget, PaddingAwareSchedulingBudget) + budget.add_prefill_seqs(seq_group.request_id, num_new_seqs, + max_prefill_seq_len) # Queue requests that couldn't be scheduled. waiting_queue.extendleft(leftover_waiting_sequences) @@ -991,10 +1093,18 @@ def _schedule_default(self) -> SchedulerOutputs: be swapped or preempted. """ # Include running requests to the budget. 
- budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) + budget: SchedulingBudget + if self.scheduler_config.use_padding_aware_scheduling: + budget = PaddingAwareSchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + max_num_prefill_seqs=self.scheduler_config.max_num_prefill_seqs + ) + else: + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) # Make sure we include num running seqs before scheduling prefill, # so that we don't schedule beyond max_num_seqs for prefill. for seq_group in self.running: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 898a8d4c6eeaa..2ac556ad26459 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -113,11 +113,13 @@ class EngineArgs: enable_prefix_caching: bool = False disable_sliding_window: bool = False use_v2_block_manager: bool = True + use_padding_aware_scheduling: bool = False swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 + max_num_prefill_seqs: Optional[int] = None max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False revision: Optional[str] = None @@ -387,6 +389,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', help='Use BlockSpaceMangerV2. By default this is set to True. ' 'Set to False to use BlockSpaceManagerV1') + parser.add_argument( + '--use-padding-aware-scheduling', + default=EngineArgs.use_padding_aware_scheduling, + action='store_true', + help=('Use padding-aware scheduling. If True, the scheduler ' + 'will consider padded tokens in prefill. ' + 'By default this is set to False. ')) parser.add_argument( '--num-lookahead-slots', type=int, @@ -441,6 +450,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=EngineArgs.max_num_seqs, help='Maximum number of sequences per iteration.') + parser.add_argument( + '--max-num-prefill-seqs', + type=int, + default=EngineArgs.max_num_prefill_seqs, + help=('Maximum number of prefill sequences per ' + 'iteration. Can be used only with padding-aware ' + 'scheduling. 
Must be <= max_num_seqs.')) parser.add_argument( '--max-logprobs', type=int, @@ -1033,6 +1049,7 @@ def create_engine_config(self) -> EngineConfig: scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, + max_num_prefill_seqs=self.max_num_prefill_seqs, max_model_len=model_config.max_model_len, use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=num_lookahead_slots, @@ -1046,7 +1063,7 @@ def create_engine_config(self) -> EngineConfig: send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, - ) + use_padding_aware_scheduling=self.use_padding_aware_scheduling) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 99dc326612588..b1891993c3442 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -13,6 +13,7 @@ import os import time from array import array +from dataclasses import dataclass, field from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -64,20 +65,134 @@ LORA_WARMUP_RANK = 8 -def subtuple(obj: object, - typename: str, - to_copy: List[str], - to_override: Optional[Dict[str, object]] = None): - if obj is None: - return None - if to_override is None: - to_override = {} - fields = set(to_copy) | set(to_override.keys()) - values = {f: to_override.get(f, getattr(obj, f)) for f in fields} - if typename not in _TYPE_CACHE: - _TYPE_CACHE[typename] = collections.namedtuple(typename, - ' '.join(fields)) - return _TYPE_CACHE[typename](**values) +class Singleton(type): + _instances: Dict[type, object] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, + cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +@dataclass +class HPUBucketingGlobalState(metaclass=Singleton): + prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_buckets: List[Tuple[int, int]] = field(init=False) + decode_buckets: List[Tuple[int, int]] = field(init=False) + + +class HPUBucketingContext(metaclass=Singleton): + global_state = HPUBucketingGlobalState() + + def __init__(self, max_num_seqs, max_num_prefill_seqs, block_size, + max_num_batched_tokens): + self.max_num_seqs = max_num_seqs + self.max_num_prefill_seqs = max_num_prefill_seqs + self.block_size = block_size + self.max_num_batched_tokens = max_num_batched_tokens + self._setup_buckets() + + def _setup_buckets(self) -> None: + align_bs = lambda x: min(self.max_num_seqs, x) + #FIXME: The default values should be max_model_len + max_prompt_seq = 1024 + max_decode_seq = 2048 + self.global_state.prompt_bs_bucket_cfg = read_bucket_settings( + 'prompt', + 'bs', + min=1, + step=align_bs(32), + max=self.max_num_prefill_seqs) + self.global_state.decode_bs_bucket_cfg = read_bucket_settings( + 'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs) + self.global_state.prompt_seq_bucket_cfg = \ + read_bucket_settings( + 'prompt', + 'seq', + min=self.block_size, + step=self.block_size, + max=max_prompt_seq) + self.global_state.decode_block_bucket_cfg = \ + read_bucket_settings( + 'decode', + 'block', + 
min=self.block_size, + step=self.block_size, + max=max(self.block_size, + self.max_num_seqs * max_decode_seq // self.block_size)) + + msg = ("Prompt bucket config (min, step, max_warmup) " + f"bs:{self.global_state.prompt_bs_bucket_cfg}, " + f"seq:{self.global_state.prompt_seq_bucket_cfg}") + logger.info(msg) + + msg = ("Decode bucket config (min, step, max_warmup) " + f"bs:{self.global_state.decode_bs_bucket_cfg}, " + f"block:{self.global_state.decode_block_bucket_cfg}") + logger.info(msg) + + def generate_prompt_buckets(self): + self.global_state.prompt_buckets, prompt_omitted_buckets = \ + generate_prompt_buckets( + self.global_state.prompt_bs_bucket_cfg, + self.global_state.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) + + msg = (f"Generated {len(self.global_state.prompt_buckets)} " + f"prompt buckets [bs, seq]: \ + {list(sorted(self.global_state.prompt_buckets))}") + logger.info(msg) + + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + + def generate_decode_buckets(self, max_blocks): + self.global_state.decode_buckets = generate_decode_buckets( + self.global_state.decode_bs_bucket_cfg, + self.global_state.decode_block_bucket_cfg, max_blocks) + logger.info("Generated %d decode buckets [bs, total_blocks]: %s", + len(self.global_state.decode_buckets), + list(sorted(self.global_state.decode_buckets))) + + def get_padded_prompt_batch_size(self, batch_size): + return find_bucket(batch_size, self.global_state.prompt_bs_bucket_cfg) + + def get_padded_decode_batch_size(self, batch_size): + return find_bucket(batch_size, self.global_state.decode_bs_bucket_cfg) + + def get_padded_prompt_seq_len(self, seq_len): + return find_bucket(seq_len, self.global_state.prompt_seq_bucket_cfg) + + def get_padded_decode_num_blocks(self, num_blocks): + return find_bucket(num_blocks, + self.global_state.decode_block_bucket_cfg) + + def get_padded_batch_size(self, batch_size, is_prompt): + if is_prompt: + return self.get_padded_prompt_batch_size(batch_size) + return self.get_padded_decode_batch_size(batch_size) + + def get_padded_seq_or_block(self, seq_or_block, is_prompt): + if is_prompt: + return self.get_padded_prompt_seq_len(seq_or_block) + return self.get_padded_decode_num_blocks(seq_or_block) + + @property + def prompt_buckets(self): + return self.global_state.prompt_buckets + + @property + def decode_buckets(self): + return self.global_state.decode_buckets def read_bucket_settings(phase: str, dim: str, **defaults): @@ -208,6 +323,22 @@ def find_bucket(value: int, config: Tuple[int, int, int]): return max(bmin, min(next_step, next_pow)) +def subtuple(obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, + ' '.join(fields)) + return _TYPE_CACHE[typename](**values) + + def align_workers(value, op): group = get_world_group().cpu_group world_size = torch.distributed.get_world_size() @@ -532,6 +663,9 @@ def __init__( self.device = self.device_config.device self.enforce_eager = self.model_config.enforce_eager self.max_num_seqs = 
self.scheduler_config.max_num_seqs + self.max_num_prefill_seqs = self.scheduler_config.max_num_prefill_seqs \ + if self.scheduler_config.max_num_prefill_seqs is not None \ + else self.max_num_seqs self.max_model_len = self.scheduler_config.max_model_len self.max_num_batched_tokens = \ self.scheduler_config.max_num_batched_tokens @@ -560,7 +694,12 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None - self._setup_buckets() + self.bucketing_ctx = HPUBucketingContext( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_prefill_seqs, self.block_size, + self.max_num_batched_tokens) + self.graphed_buckets: Set[Any] = set() + self._set_gc_threshold() def _set_gc_threshold(self) -> None: @@ -669,47 +808,6 @@ def _use_graphs(self, batch_size, seq_len, is_prompt): def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens - def _setup_buckets(self) -> None: - align_bs = lambda x: min(self.max_num_seqs, x) - max_bucket_cfg = 64 - #FIXME: The default values should be max_model_len - max_prompt_seq = 1024 - max_decode_seq = 2048 - self.prompt_bs_bucket_cfg = read_bucket_settings( - 'prompt', - 'bs', - min=1, - step=align_bs(32), - max=align_bs(max_bucket_cfg)) - self.decode_bs_bucket_cfg = read_bucket_settings('decode', - 'bs', - min=1, - step=align_bs(32), - max=self.max_num_seqs) - self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', - 'seq', - min=self.block_size, - step=self.block_size, - max=max_prompt_seq) - self.decode_block_bucket_cfg = read_bucket_settings( - 'decode', - 'block', - min=self.block_size, - step=self.block_size, - max=max(self.block_size, - self.max_num_seqs * max_decode_seq // self.block_size)) - self.graphed_buckets: Set[Any] = set() - - msg = ("Prompt bucket config (min, step, max_warmup) " - f"bs:{self.prompt_bs_bucket_cfg}, " - f"seq:{self.prompt_seq_bucket_cfg}") - logger.info(msg) - - msg = ("Decode bucket config (min, step, max_warmup) " - f"bs:{self.decode_bs_bucket_cfg}, " - f"block:{self.decode_block_bucket_cfg}") - logger.info(msg) - def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -825,7 +923,7 @@ def _prepare_prompt( assert max_query_len > 0 max_prompt_len = max( - find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), + self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)), self.block_size) lora_ids: List[int] = [] @@ -983,8 +1081,8 @@ def _prepare_decode( for b_u, lb in zip(blocks_used, last_block)] block_usage = list(itertools.chain(*block_usage)) - block_bucket_size = find_bucket(len(block_list), - self.decode_block_bucket_cfg) + block_bucket_size = \ + self.bucketing_ctx.get_padded_decode_num_blocks(len(block_list)) block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) block_mapping = pad_list(block_mapping, block_bucket_size, 0) block_usage = pad_list(block_usage, block_bucket_size, 0) @@ -1052,9 +1150,8 @@ def prepare_input_tensors( self.profiler.start('internal', base_event_name) real_batch_size = len(seq_group_metadata_list) - bucket_cfg = self.prompt_bs_bucket_cfg if is_prompt else \ - self.decode_bs_bucket_cfg - batch_size_padded = find_bucket(real_batch_size, bucket_cfg) + batch_size_padded = self.bucketing_ctx.get_padded_batch_size( + real_batch_size, is_prompt) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() if batch_size_padding > 0: @@ -1258,9 +1355,11 @@ def 
create_dummy_seq_group_metadata(self, def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - max_batch_size = self.prompt_bs_bucket_cfg[-1] - max_seq_len = min(self.prompt_seq_bucket_cfg[-1], - self.max_num_batched_tokens // max_batch_size) + max_batch_size = self.bucketing_ctx.global_state.prompt_bs_bucket_cfg[ + -1] + max_seq_len = min( + self.bucketing_ctx.global_state.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size) self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, False, True) @@ -1473,34 +1572,12 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: return self.profiler.start('internal', 'warmup') max_blocks = kv_caches[0][0].size(0) - - self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( - self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) - - msg = ( - f"Generated {len(self.prompt_buckets)} " - f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}") - logger.info(msg) - - msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") - logger.info(msg) - - msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" - logger.debug(msg) - - self.decode_buckets = generate_decode_buckets( - self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, - max_blocks) - logger.info("Generated %d decode buckets [bs, total_blocks]: %s", - len(self.decode_buckets), - list(sorted(self.decode_buckets))) + self.bucketing_ctx.generate_prompt_buckets() + self.bucketing_ctx.generate_decode_buckets(max_blocks) if not htorch.utils.internal.is_lazy() and not self.enforce_eager: - cache_size_limit = len(self.prompt_buckets) + len( - self.decode_buckets) + 1 + cache_size_limit = len(self.bucketing_ctx.prompt_buckets) + len( + self.bucketing_ctx.decode_buckets) + 1 torch._dynamo.config.cache_size_limit = max( cache_size_limit, torch._dynamo.config.cache_size_limit) # Multiply by 8 to follow the original default ratio between @@ -1527,8 +1604,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'Please update Gaudi Software Suite.') with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): - self.warmup_all_buckets(self.prompt_buckets, True, kv_caches) - self.warmup_all_buckets(self.decode_buckets, False, kv_caches) + self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True, + kv_caches) + self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, + kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ @@ -1558,12 +1637,12 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'max_bs') mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( - prompt_strategy, self.prompt_buckets, True, kv_caches, - prompt_available_memory) + prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( - decode_strategy, self.decode_buckets, False, kv_caches, - decode_available_memory) + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1572,8 +1651,8 @@ def warmup_model(self, kv_caches: 
List[torch.Tensor]) -> None: and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, self.prompt_buckets, True, - kv_caches, + prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) @@ -1584,14 +1663,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: and not decode_captured_all \ and prompt_captured_all: mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, self.decode_buckets, False, kv_caches, + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) - self.log_graph_warmup_summary(self.prompt_buckets, True, - mem_post_prompt) - self.log_graph_warmup_summary(self.decode_buckets, False, - mem_post_decode) + self.log_graph_warmup_summary( + self.bucketing_ctx.prompt_buckets, True, mem_post_prompt) + self.log_graph_warmup_summary( + self.bucketing_ctx.decode_buckets, False, mem_post_decode) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage()
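Usage sketch for the scheduler knobs added in this patch: padding-aware scheduling is opt-in via --use-padding-aware-scheduling, --max-num-prefill-seqs is only valid together with it, and it cannot be combined with chunked prefill. The model name and numeric limits below are placeholders.

# Programmatic equivalent of:
#   --use-padding-aware-scheduling --max-num-prefill-seqs 16
# using the EngineArgs fields introduced above.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",          # placeholder model
    use_padding_aware_scheduling=True,  # account for padded prefill tokens
    max_num_prefill_seqs=16,            # optional cap; requires the flag above
    max_num_batched_tokens=4096,
    enable_chunked_prefill=False,       # incompatible with padding-aware scheduling
)

# create_engine_config() builds a SchedulerConfig, whose _verify_args() raises
# if max_num_prefill_seqs is set without padding-aware scheduling or if
# chunked prefill is enabled together with it.
scheduler_config = engine_args.create_engine_config().scheduler_config
assert scheduler_config.use_padding_aware_scheduling
assert scheduler_config.max_num_prefill_seqs == 16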