Skip to content

Commit

Permalink
Model: Fix logprobs unwrapping
Browse files Browse the repository at this point in the history
Take the log of the token probs, since they are already normalized, so the
result reflects the proper logprob value. Also, don't error out when a token
prob is missing from the dict: use zip_longest so the shorter side is padded
with None instead of raising.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Feb 9, 2024
1 parent c7428f0 commit 43bba52
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""
import gc
from itertools import zip_longest
import pathlib
import time

Expand Down Expand Up @@ -486,9 +487,11 @@ def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
)
top_values = top_values[0].tolist()

return dict(zip(top_tokens, top_values, strict=True))
return dict(zip_longest(top_tokens, top_values))

def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
normalized_probs = torch.log(token_probs)

tokens = list(
map(
lambda index: self.tokenizer.extended_id_to_piece.get(
Expand All @@ -498,7 +501,7 @@ def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
)
)

return dict(zip(tokens, token_probs[0].tolist(), strict=True))
return dict(zip_longest(tokens, normalized_probs[0].tolist()))

def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
Expand Down

0 comments on commit 43bba52

Please sign in to comment.