Skip to content

Commit

Permalink
api: add sampling/engine option to return only deltas or final output (
Browse files Browse the repository at this point in the history
…#1035)

* api: add sampling/engine option to return only deltas or final output

* fix streaming
  • Loading branch information
AlpinDale authored Dec 27, 2024
1 parent 1390915 commit 055c890
Show file tree
Hide file tree
Showing 11 changed files with 727 additions and 279 deletions.
79 changes: 53 additions & 26 deletions aphrodite/common/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Sequence as GenericSequence
from typing import Union

from aphrodite.common.sampling_params import RequestOutputKind
from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
SampleLogprobs, SequenceGroup,
SequenceStatus)
Expand Down Expand Up @@ -92,7 +93,7 @@ def __init__(
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
finished: bool,
Expand All @@ -113,19 +114,25 @@ def __init__(
self.encoder_prompt_token_ids = encoder_prompt_token_ids

@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
if seq_group.sampling_params is None:
def from_seq_group(cls,
seq_group: SequenceGroup) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")
# Get the top-n sequences.
n = seq_group.sampling_params.n
finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None
seqs = seq_group.get_seqs()
if n == 1:
if len(seqs) == 1:
top_n_seqs = seqs
else:
if seq_group.sampling_params.use_beam_search:
# Get the top-n sequences.
n = sampling_params.n
if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)
sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
Expand All @@ -135,26 +142,46 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
CompletionOutput(
seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.data._output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason) for seq in top_n_seqs
]
include_logprobs = sampling_params.logprobs is not None
text_buffer_length = sampling_params.output_text_buffer_length
delta = sampling_params.output_kind == RequestOutputKind.DELTA
outputs = []
include_prompt = True
for seq in top_n_seqs:
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)
output_token_ids = seq.get_output_token_ids_to_return(delta)
output_logprobs = seq.output_logprobs if include_logprobs else None
if delta:
# Slice logprobs delta if applicable
if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):]
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > len(
output_token_ids):
include_prompt = False
outputs.append(
CompletionOutput(
seqs.index(seq), output_text, output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason))

# Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished()
if include_prompt:
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
else:
prompt = None
prompt_token_ids = None
encoder_prompt = None
encoder_prompt_token_ids = None
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)
return cls(
Expand Down
17 changes: 16 additions & 1 deletion aphrodite/common/sampling_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union

Expand All @@ -24,6 +24,15 @@ class SamplingType(IntEnum):
RANDOM_SEED = 2
BEAM = 3


class RequestOutputKind(Enum):
    """Controls how much of the generated output each RequestOutput carries."""
    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutputs
    FINAL_ONLY = 2

class SamplerID(IntEnum):
# Mirror these in aphrodite/modeling/layers/sampler.py
# Values out of order to keep backwards compatibility
Expand Down Expand Up @@ -277,6 +286,7 @@ class SamplingParams(
dry_range: int = 0
skew: float = 0.0
sampler_priority: Optional[List[int]] = []
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length: int = 0
Expand Down Expand Up @@ -331,6 +341,7 @@ class SamplingParams(
"dry_range": 0,
"skew": 0.0,
"sampler_priority": [],
"output_kind": RequestOutputKind.CUMULATIVE,
}

def __post_init__(self) -> None:
Expand Down Expand Up @@ -513,6 +524,10 @@ def _verify_args(self) -> None:
f"{missing_names}"
)

if self.best_of != self.n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")

def _verify_beam_search(self) -> None:
if self.best_of == 1:
raise ValueError("best_of must be greater than 1 when using beam "
Expand Down
40 changes: 35 additions & 5 deletions aphrodite/common/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Tuple, Union, cast)
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast

import msgspec
import torch
Expand Down Expand Up @@ -397,6 +398,10 @@ def __init__(
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None

# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_text_offset: int = 0

# Used for incremental detokenization
self.prefix_offset = 0
self.read_offset = 0
Expand Down Expand Up @@ -452,11 +457,36 @@ def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0

def get_output_text_to_return(self, buffer_length: int,
                              delta: bool) -> str:
    """Return the generated text to stream back to the caller.

    With ``delta=False`` the full output text is returned (minus the
    trailing stop-string buffer while still generating); with
    ``delta=True`` only the text produced since the previous delta
    call is returned.
    """
    # While the sequence is still running, hold back `buffer_length`
    # trailing characters so a partially matched stop string is never
    # streamed out; once finished, everything is visible.
    visible = len(self.output_text)
    if buffer_length and not self.is_finished():
        visible = max(0, visible - buffer_length)
    if not delta:
        return self.output_text[:visible]
    # Delta mode: emit only what appeared since the last delta call.
    previous = self._last_output_text_offset
    if previous >= visible:
        return ""
    self._last_output_text_offset = visible
    return self.output_text[previous:visible]

def get_output_token_ids_to_return(self,
                                   delta: bool) -> GenericSequence[int]:
    """Return the generated token ids to stream back to the caller.

    With ``delta=False`` all output token ids are returned; with
    ``delta=True`` only the ids produced since the previous delta
    call are returned.
    """
    if not delta:
        return self.get_output_token_ids()
    total = self.get_output_len()
    previous = self._last_token_ids_offset
    if previous >= total:
        # Nothing new has been generated since the last delta call.
        return ()
    self._last_token_ids_offset = total
    return self.data._output_token_ids[previous:]

def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
Expand Down
Loading

0 comments on commit 055c890

Please sign in to comment.