Skip to content

Commit

Permalink
core: support logprobs with multi-step scheduling (#963)
Browse files Browse the repository at this point in the history
* deferred sampler results

* fix imports and implement within multistep worker

* update tests

* fix test

* fix sequence test

* fix unrelated gguf ruff issue
  • Loading branch information
AlpinDale authored Dec 22, 2024
1 parent 34e8606 commit 0dfa6b6
Show file tree
Hide file tree
Showing 108 changed files with 917 additions and 424 deletions.
66 changes: 0 additions & 66 deletions aphrodite/common/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,72 +1046,6 @@ def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"


class SamplerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This data structure implements methods, so it can be used like a list, but
also has optional fields for device tensors.
"""

outputs: List[CompletionSequenceGroupOutput]

# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional[torch.Tensor] = None

# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None

# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncLLMEngine to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None

# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None

# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states: Optional[torch.Tensor] = None

# Time taken in the forward pass for this across all workers
model_forward_time: Optional[float] = None

def __getitem__(self, idx: int):
return self.outputs[idx]

def __setitem__(self, idx: int, value):
self.outputs[idx] = value

def __len__(self):
return len(self.outputs)

def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs

def __repr__(self) -> str:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
else self.sampled_token_probs.shape)
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
self.sampled_token_ids.shape)
return (
f"SamplerOutput(outputs={self.outputs}, "
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")


class PoolerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
Expand Down
7 changes: 4 additions & 3 deletions aphrodite/engine/aphrodite_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
ExecuteModelRequest, SamplerOutput,
Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
ExecuteModelRequest, Sequence,
SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from aphrodite.common.utils import Counter, Device
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.engine.metrics_types import StatLoggerBase, Stats
Expand All @@ -42,6 +42,7 @@
SingletonPromptInputs)
from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.multimodal import MultiModalDataDict
from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/engine/async_aphrodite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import print_warning_once
from aphrodite.engine.aphrodite_engine import (AphroditeEngine,
DecoderPromptComponents,
Expand All @@ -29,6 +29,7 @@
SingletonPromptInputs)
from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.processing.scheduler import SchedulerOutputs
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.transformers_utils.tokenizer import AnyTokenizer
Expand Down
14 changes: 11 additions & 3 deletions aphrodite/engine/output_processor/multi_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from aphrodite.common.utils import Counter
from aphrodite.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from aphrodite.engine.output_processor.single_step import (
single_step_process_prompt_logprob)
from aphrodite.engine.output_processor.stop_checker import StopChecker
from aphrodite.processing.scheduler import Scheduler
from aphrodite.transformers_utils.detokenizer import Detokenizer
Expand Down Expand Up @@ -46,9 +48,15 @@ def __init__(

def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
# TODO: Prompt logprob currently not implemented in multi step
# workers.
self._log_prompt_logprob_unsupported_warning_once()
"""Process prompt logprobs associated with each step of a multi-step-
scheduled computation.
Args:
seq_group: the outputs are associated with this :class:`SequenceGroup`
outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
"""
for output in outputs:
# Concatenate single-step prompt logprob processing results.
single_step_process_prompt_logprob(self, seq_group, output)

@staticmethod
@functools.lru_cache()
Expand Down
61 changes: 44 additions & 17 deletions aphrodite/engine/output_processor/single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,42 @@
from aphrodite.transformers_utils.detokenizer import Detokenizer


def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: SequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
Do nothing if the output has no prompt logprobs.
Account for the fact that transformers do not compute first-token logprobs.
Args:
sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
prompt_logprobs = output.prompt_logprobs

# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []

assert hasattr(sg_output_proc, 'detokenizer')
if (seq_group.sampling_params.detokenize
and sg_output_proc.detokenizer):
sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))

seq_group.prompt_logprobs.extend(prompt_logprobs)


class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles "output processing" logic,
which happens after the model returns generated token ids and before
Expand Down Expand Up @@ -57,25 +93,16 @@ def process_outputs(self, sequence_group: SequenceGroup,

def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Process prompt logprobs associated with one step of a single-step-
scheduled computation.
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0]
prompt_logprobs = output.prompt_logprobs

# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []
if seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))
seq_group.prompt_logprobs.extend(prompt_logprobs)
single_step_process_prompt_logprob(self, seq_group, output)

def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput,
Expand Down
4 changes: 2 additions & 2 deletions aphrodite/engine/output_processor/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Sequence as GenericSequence
from typing import Union

from aphrodite.common.sequence import (PoolerOutput, SamplerOutput,
SequenceGroupOutput)
from aphrodite.common.sequence import PoolerOutput, SequenceGroupOutput
from aphrodite.modeling.layers.sampler import SamplerOutput


def create_output_by_sequence_group(
Expand Down
2 changes: 1 addition & 1 deletion aphrodite/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import SamplerOutput
from aphrodite.inputs.data import PromptInputs
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.processing.scheduler import SchedulerOutputs
from aphrodite.prompt_adapter.request import PromptAdapterRequest

Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/cpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import aphrodite.common.envs as envs
from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (GiB_bytes, get_aphrodite_instance_id,
get_distributed_init_method, get_open_port,
make_async)
Expand All @@ -16,6 +16,7 @@
ResultHandler,
WorkerMonitor)
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.task_handler.worker_base import WorkerWrapperBase

Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/distributed_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.gpu_executor import GPUExecutor
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput


class DistributedGPUExecutor(GPUExecutor):
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/executor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
LoRAConfig, ModelConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.prompt_adapter.request import PromptAdapterRequest


Expand Down
4 changes: 2 additions & 2 deletions aphrodite/executor/gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from loguru import logger

from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
SamplerOutput)
from aphrodite.common.sequence import ExecuteModelRequest, PoolerOutput
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.task_handler.worker_base import WorkerBase, WorkerWrapperBase

Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (_run_task_with_lock,
cuda_device_count_stateless,
get_aphrodite_instance_id,
Expand All @@ -21,6 +21,7 @@
from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler,
WorkerMonitor)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.triton_utils import maybe_set_triton_cache_manager


Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/neuron_executor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List, Set, Tuple

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput


class NeuronExecutor(ExecutorBase):
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/openvino_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import aphrodite.common.envs as envs
from aphrodite.common.config import CacheConfig, ModelConfig
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (GiB_bytes, get_distributed_init_method,
get_ip, get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput

APHRODITE_OPENVINO_KVCACHE_SPACE = envs.APHRODITE_OPENVINO_KVCACHE_SPACE
APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION = (
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (_run_task_with_lock,
get_aphrodite_instance_id,
get_distributed_init_method, get_ip,
Expand All @@ -17,6 +17,7 @@
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.msgspec_utils import encode_hook
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.modeling.layers.sampler import SamplerOutput

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/ray_tpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (get_aphrodite_instance_id,
get_distributed_init_method, get_ip,
get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.executor.tpu_executor import TPUExecutor
from aphrodite.modeling.layers.sampler import SamplerOutput

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/executor/tpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import torch
from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput


class TPUExecutor(ExecutorBase):
Expand Down
4 changes: 2 additions & 2 deletions aphrodite/executor/xpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
LoRAConfig, ModelConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
SamplerOutput)
from aphrodite.common.sequence import ExecuteModelRequest, PoolerOutput
from aphrodite.common.utils import make_async
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.gpu_executor import GPUExecutor
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.task_handler.worker_base import WorkerBase


Expand Down
Loading

0 comments on commit 0dfa6b6

Please sign in to comment.