Skip to content

Commit

Permalink
Model: Switch logprobs to use post-sampling
Browse files Browse the repository at this point in the history
Previously, pre-sampling logprobs were used from the raw logits,
but newer versions of exl2 allow for returning token probs post-sampling.
Convert these to logprobs and send to the user.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Feb 12, 2024
1 parent a79c42f commit 36ea794
Showing 1 changed file with 47 additions and 36 deletions.
83 changes: 47 additions & 36 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,15 @@ def unload(self, loras_only: bool = False):
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string"""

return self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
)[0].tolist()
return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
)
.flatten()
.tolist()
)

def decode_tokens(self, ids: List[int], **kwargs):
"""Wrapper to decode tokens from a list of IDs"""
Expand All @@ -484,35 +488,24 @@ def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
"unk_token": self.tokenizer.unk_token,
}

def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
normalized_logits = torch.log_softmax(logits, dim=-1)
top_values, top_ids = torch.topk(normalized_logits, max_logprobs, dim=-1)

def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
top_tokens = list(
map(
lambda index: self.tokenizer.extended_id_to_piece.get(
index, self.tokenizer.id_to_piece[index]
),
top_ids[0].tolist(),
token_ids.flatten().tolist(),
)
)
top_values = top_values[0].tolist()

return dict(zip_longest(top_tokens, top_values))

def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
normalized_probs = torch.log(token_probs)
top_values = torch.log(token_probs).flatten().tolist()

tokens = list(
map(
lambda index: self.tokenizer.extended_id_to_piece.get(
index, self.tokenizer.id_to_piece[index]
),
token_ids[0].tolist(),
)
# Cannot return -inf in JSON
cleaned_values = list(
map(lambda value: -1000 if value == float("-inf") else value, top_values)
)

return dict(zip_longest(tokens, normalized_probs[0].tolist()))
return dict(zip_longest(top_tokens, cleaned_values))

def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
Expand Down Expand Up @@ -707,7 +700,10 @@ def generate_gen(self, prompt: str, **kwargs):
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
logit_bias = kwargs.get("logit_bias")

# Logprobs
request_logprobs = unwrap(kwargs.get("logprobs"), 0)
self.generator.return_top_tokens = request_logprobs

# Override sampler settings for temp = 0
if gen_settings.temperature == 0:
Expand Down Expand Up @@ -827,14 +823,20 @@ def generate_gen(self, prompt: str, **kwargs):
if auto_scale_penalty_range:
gen_settings.token_repetition_range = generated_tokens

# Generate
chunk, eos, tokens, token_probs, logits = self.generator.stream()
# Run dict generation
# Guarantees return of chunk, eos, and chunk_token_ids
raw_generation = self.generator.stream_ex()

if token_healing:
# Extract healed token
ids[:, -1] = self.generator.sequence_ids[:, -2]
token_healing = False

# Get parameters that will always exist
chunk = raw_generation["chunk"]
eos = raw_generation["eos"]
tokens = raw_generation["chunk_token_ids"]

save_tokens = torch.cat(
(save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
)
Expand All @@ -858,17 +860,26 @@ def generate_gen(self, prompt: str, **kwargs):
}

if request_logprobs > 0:
# Get sampled token probs
if token_probs.numel() > 0 and tokens.numel() > 0:
generation["token_probs"] = self.get_token_probs(
tokens, token_probs
)

# Get logprob choices
if logits.numel() > 0:
generation["logprobs"] = self.get_logprobs(
logits, request_logprobs
)
# Get top tokens and probs
top_tokens = unwrap(
raw_generation.get("top_tokens"),
torch.empty((1, 0, 1), dtype=torch.long),
)

top_probs = unwrap(
raw_generation.get("top_probs"),
torch.empty((1, 0, 1), dtype=torch.float),
)

if top_tokens.numel() > 0 and top_probs.numel() > 0:
logprobs = self.get_logprobs(top_tokens, top_probs)
generation["logprobs"] = logprobs

# The first logprob is the selected token prob
generation["token_probs"] = {
token: logprobs[token]
for token in list(logprobs.keys())[:1]
}

yield generation
full_response += chunk_buffer
Expand Down

0 comments on commit 36ea794

Please sign in to comment.