From 564d1976870a17ab3be9c2dcddec30f62433a058 Mon Sep 17 00:00:00 2001
From: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Date: Sat, 21 Dec 2024 14:57:21 -0800
Subject: [PATCH] spec decode: match the original rank computation impl for
 spec decoding (#954)

---
 aphrodite/spec_decode/util.py   |  4 ++--
 tests/spec_decode/test_utils.py | 19 ++++++++++++++++++-
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/aphrodite/spec_decode/util.py b/aphrodite/spec_decode/util.py
index 5fee22c7a..e0735f5a7 100644
--- a/aphrodite/spec_decode/util.py
+++ b/aphrodite/spec_decode/util.py
@@ -43,8 +43,8 @@ def get_sampled_token_logprobs(
                                        sampled_token_ids, ]
     expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
         -1, -1, vocab_size)
-    sampled_token_ids_ranks = (logprob_tensor >=
-                               expanded_selected_logprobs).sum(-1)
+    sampled_token_ids_ranks = (logprob_tensor >
+                               expanded_selected_logprobs).sum(-1).add_(1)
 
     return sampled_token_ids_ranks, selected_logprobs
 
diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py
index 5a06f3ee9..90df760e4 100644
--- a/tests/spec_decode/test_utils.py
+++ b/tests/spec_decode/test_utils.py
@@ -5,9 +5,11 @@
 from aphrodite.common.sequence import SequenceGroupMetadata, get_all_seq_ids
 from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
+from aphrodite.modeling.layers.sampler import _get_ranks
 from aphrodite.modeling.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
-from aphrodite.spec_decode.util import split_batch_by_proposal_len
+from aphrodite.spec_decode.util import (get_sampled_token_logprobs,
+                                        split_batch_by_proposal_len)
 
 
 def test_get_all_seq_ids():
@@ -126,3 +128,18 @@ def mock_spec_decode_sampler(acceptance_sampler_method):
         return sampler
     else:
         raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
+
+
+def test_get_sampled_token_logprobs():
+    """Verify get_sampled_token_logprobs returns consistent rankings
+    with regular get_ranks when probabilities match exactly.
+    """
+    logprob_tensor = torch.tensor(
+        [[[-.1, -.1]] * 2])  # shape (num_steps, batch_size, vocab_size)
+    sampled_token_tensor = torch.tensor([[1,
+                                          0]])  # shape (num_steps, batch_size)
+    ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
+                                                   sampled_token_tensor)
+    ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
+                               sampled_token_tensor.reshape(-1))
+    assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)