Skip to content

Commit

Permalink
spec decode: match the original rank computation impl for spec decodi…
Browse files Browse the repository at this point in the history
…ng (#954)
  • Loading branch information
AlpinDale authored Dec 21, 2024
1 parent 2aabf8f commit 564d197
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
4 changes: 2 additions & 2 deletions aphrodite/spec_decode/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def get_sampled_token_logprobs(
sampled_token_ids, ]
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
-1, -1, vocab_size)
sampled_token_ids_ranks = (logprob_tensor >=
expanded_selected_logprobs).sum(-1)
sampled_token_ids_ranks = (logprob_tensor >
expanded_selected_logprobs).sum(-1).add_(1)

return sampled_token_ids_ranks, selected_logprobs

Expand Down
19 changes: 18 additions & 1 deletion tests/spec_decode/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

from aphrodite.common.sequence import SequenceGroupMetadata, get_all_seq_ids
from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
from aphrodite.modeling.layers.sampler import _get_ranks
from aphrodite.modeling.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from aphrodite.spec_decode.util import split_batch_by_proposal_len
from aphrodite.spec_decode.util import (get_sampled_token_logprobs,
split_batch_by_proposal_len)


def test_get_all_seq_ids():
Expand Down Expand Up @@ -126,3 +128,18 @@ def mock_spec_decode_sampler(acceptance_sampler_method):
return sampler
else:
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")


def test_get_sampled_token_logprobs():
"""Verify get_sampled_token_logprobs returns consistent rankings
with regular get_ranks when probabilities match exactly.
"""
logprob_tensor = torch.tensor(
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
sampled_token_tensor = torch.tensor([[1,
0]]) # shape (num_steps, batch_size)
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
sampled_token_tensor)
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
sampled_token_tensor.reshape(-1))
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)

0 comments on commit 564d197

Please sign in to comment.