core: support prompt logprobs in multi-step (#1060)
AlpinDale authored Jan 1, 2025
1 parent 9a7d551 commit 58aff37
Showing 5 changed files with 289 additions and 58 deletions.
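For context on what this change enables, here is a minimal usage sketch: requesting prompt logprobs alongside sample logprobs while multi-step scheduling is active. The `LLM` entry point and the `num_scheduler_steps` engine argument are assumptions about the surrounding API and are not part of this diff; the `SamplingParams` fields and the `prompt_logprobs` attribute on request outputs do appear in the changes below.

```python
# Hedged usage sketch -- the LLM constructor and num_scheduler_steps flag are
# assumed, not shown in this commit; SamplingParams fields mirror the test
# helpers extended below.
from aphrodite import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", num_scheduler_steps=8)  # multi-step scheduling (assumed flag)

params = SamplingParams(
    temperature=0.0,
    max_tokens=16,
    logprobs=5,         # top-5 logprobs per sampled token
    prompt_logprobs=5,  # top-5 logprobs per prompt token -- now usable with multi-step
)

for req_output in llm.generate(["The capital of France is"], params):
    print(req_output.prompt_logprobs)      # per-prompt-token logprobs
    print(req_output.outputs[0].logprobs)  # per-sampled-token logprobs
```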
70 changes: 51 additions & 19 deletions aphrodite/worker/multi_step_model_runner.py
@@ -604,29 +604,61 @@ def _pythonize_sampler_output(
assert model_input.frozen_model_input is not None
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input.sampling_metadata is not None
sampling_metadata = frozen_model_input.sampling_metadata
# samples generation should have been skipped
assert not output.outputs
pinned_buffer = pinned_sampled_token_buffer[: model_input.num_queries]
# CPU GPU sync
pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
# We guarantee output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However we should check whether logprobs pythonization may
# be skipped entirely, i.e. because no logprobs were requested
# or pythonization was not deferred. To that end,
#
# * `prompt_logprobs_are_requested_for_prefill` signals that
# there are *any* prefill-phase requests which specify that
# prompt logprobs should be returned.
#
# * `any_logprobs_are_requested` signals that there are any
# requests which (1) specify that sample logprobs should be
# returned, or (2) are in the prefill phase AND specify that
# prompt logprobs should be returned.
#
# Later on, these flags cause adjustments to the pythonization
# process to accommodate logprobs.
seq_groups = sampling_metadata.seq_groups
prompt_logprobs_are_requested_for_prefill = any([
sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
for sg in seq_groups
])
any_logprobs_are_requested = (
prompt_logprobs_are_requested_for_prefill
or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))
if prompt_logprobs_are_requested_for_prefill:
# CPU GPU sync, after gathering *only* sampled tokens (since
# requesting prompt logprobs leads `sampled_token_ids` to
# include prompt token ids in addition to sampled token ids.)
sample_idx_tensor = torch.tensor(
[sdx for sg in seq_groups for sdx in sg.sample_indices])
pinned_buffer = pinned_buffer.copy_(
sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
else:
# CPU GPU sync
pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
non_blocking=False)
# this will not block as the tensors are already on CPU
samples_list = pinned_buffer.tolist()
sampling_metadata = frozen_model_input.sampling_metadata

skip_sampler_cpu_output = (
frozen_model_input.sampling_metadata.skip_sampler_cpu_output)

# We are guaranteed output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However this computation may be skipped entirely
# if no pythonization was deferred.
seq_groups = sampling_metadata.seq_groups
logprobs_are_requested = any([
sg.sampling_params.logprobs is not None
or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
])
# *Don't* skip logprobs pythonization *if*:
# * Any requests require logprobs to be returned in this
# iteration AND
# * These requests are being scheduled in a fashion which
# defers pythonization (i.e. multi-step scheduling.)
do_pythonize_logprobs = (skip_sampler_cpu_output
and logprobs_are_requested)
and any_logprobs_are_requested)
(
prompt_logprobs,
sample_logprobs,
@@ -651,7 +683,7 @@ def _pythonize_sampler_output(
prompt_logprobs[sgdx],
sample_logprobs[sgdx],
)
elif logprobs_are_requested:
elif any_logprobs_are_requested:
(
group_prompt_logprobs,
group_sample_logprobs,
@@ -677,7 +709,7 @@ def _pythonize_sampler_output(
)
seq_output.parent_seq_id = seq_ids[parent_id]
seq_output.output_token = next_token_id
if logprobs_are_requested:
if any_logprobs_are_requested:
seq_output.logprobs = group_sample_logprobs[tdx]
else:
logprobs = next(iter(seq_output.logprobs.values()))
@@ -691,19 +723,19 @@ def _pythonize_sampler_output(
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
(group_sample_logprobs[tdx]
if logprobs_are_requested else {
if any_logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
if cache is not None:
completion_seq_group_output.prompt_logprobs = \
group_prompt_logprobs if logprobs_are_requested else None
group_prompt_logprobs if any_logprobs_are_requested else None
output.outputs.append(completion_seq_group_output)
else:
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs, (group_prompt_logprobs
if logprobs_are_requested else None)))
if any_logprobs_are_requested else None)))
assert len(output.outputs) > 0
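The core of the runner change is the row gather before the device-to-host copy: when any prefill in the batch requests prompt logprobs, `sampled_token_ids` carries one row per prompt token as well as per sampled token, so only the rows at each sequence group's `sample_indices` are copied into the pinned buffer. A standalone, CPU-only illustration of that indexing follows; the shapes and index values are made up for the example.

```python
import torch

# Hypothetical: two prefill sequence groups produced 12 rows in total
# (prompt-token rows plus one sampled-token row per group).
sampled_token_ids = torch.arange(12).reshape(12, 1)
sample_indices = [5, 11]  # position of each group's sampled row (made up)

idx = torch.tensor(sample_indices)
sampled_rows = sampled_token_ids[idx, :]  # shape (num_queries, 1)

# In the runner the destination is a pinned CPU buffer and the copy forces a
# CPU-GPU sync:  pinned_buffer.copy_(sampled_token_ids[idx, :], non_blocking=False)
print(sampled_rows.tolist())  # [[5], [11]]
```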
83 changes: 51 additions & 32 deletions tests/conftest.py
@@ -25,7 +25,6 @@
from aphrodite.assets.video import VideoAsset
from aphrodite.common.config import TokenizerPoolConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sequence import SampleLogprobs
from aphrodite.common.utils import (STR_DTYPE_TO_TORCH_DTYPE,
cuda_device_count_stateless, identity,
is_cpu)
@@ -36,6 +35,8 @@
initialize_model_parallel)
from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from tests.models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs)

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
@@ -469,7 +470,7 @@ def generate_greedy_logprobs_limit(
audios: Optional[PromptAudioInput] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
) -> List[TokensTextLogprobs]:
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
@@ -525,7 +526,7 @@ def generate_encoder_decoder_greedy_logprobs_limit(
max_tokens: int,
num_logprobs: int,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
) -> List[TokensTextLogprobs]:
'''
Greedy logprobs generation for Aphrodite encoder/decoder models
'''
@@ -653,14 +654,16 @@ def generate(
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: List[RequestOutput],
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
) -> List[TokensTextLogprobsPromptLogprobs]:
outputs: List[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
assert len(req_output.outputs) > 0
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
outputs.append((output_ids, output_str, output_logprobs,
req_output.prompt_logprobs))
return outputs

def generate_w_logprobs(
@@ -670,7 +673,8 @@ def generate_w_logprobs(
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
assert sampling_params.logprobs is not None

if images is not None:
@@ -691,23 +695,35 @@ def generate_w_logprobs(
for i, video in enumerate(videos):
inputs[i]["multi_modal_data"] = {"video": video}
print(f"[INPUTS!!!!]: {inputs}, {sampling_params}")

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
# Omit prompt logprobs if not required by sampling params
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
if sampling_params.prompt_logprobs is None else
toks_str_logsprobs_prompt_logprobs)

def generate_encoder_decoder_w_logprobs(
self,
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
sampling_params: SamplingParams,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
'''
Logprobs generation for Aphrodite encoder/decoder models
'''

assert sampling_params.logprobs is not None
req_outputs = self.model.generate(encoder_decoder_prompts,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
# Omit prompt logprobs if not required by sampling params
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
if sampling_params.prompt_logprobs is None else
toks_str_logsprobs_prompt_logprobs)

def generate_greedy(
self,
@@ -725,44 +741,47 @@ def generate_greedy_logprobs(
prompts: List[str],
max_tokens: int,
num_logprobs: int,
num_prompt_logprobs: Optional[int] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)

return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=(num_prompt_logprobs),
stop_token_ids=stop_token_ids)
return self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)

def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
use_beam_search=False,
max_tokens=max_tokens,
logprobs=num_logprobs)
num_prompt_logprobs: Optional[int] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
use_beam_search=False,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=(num_prompt_logprobs),
)
'''
Greedy logprobs generation for Aphrodite encoder/decoder models
'''

outputs = self.generate_encoder_decoder_w_logprobs(
return self.generate_encoder_decoder_w_logprobs(
encoder_decoder_prompts, greedy_logprobs_params)

return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]

def generate_beam_search(
self,
prompts: List[str],
(Diff for the remaining three changed files not shown.)

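A sketch of how a test could exercise the extended helpers: the fixture name `aphrodite_runner`, the model name, and the `num_scheduler_steps` keyword are assumptions, but the 4-tuple return shape matches the `TokensTextLogprobsPromptLogprobs` type introduced above.

```python
def test_prompt_logprobs_multi_step(aphrodite_runner, example_prompts):
    # Assumed fixture and arguments; generate_greedy_logprobs is the helper changed above.
    with aphrodite_runner("facebook/opt-125m", num_scheduler_steps=8) as model:
        results = model.generate_greedy_logprobs(
            example_prompts,
            max_tokens=16,
            num_logprobs=5,
            num_prompt_logprobs=5,  # omit this to get the plain 3-tuple form back
        )
    for output_ids, output_str, sample_logprobs, prompt_logprobs in results:
        assert prompt_logprobs is not None
        assert len(sample_logprobs) == len(output_ids)
```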