Update LMContext
gabegrand committed Jul 18, 2024
1 parent 3195a80 commit ebaed8d
Showing 1 changed file with 8 additions and 4 deletions.
hfppl/distributions/lmcontext.py (12 changes: 8 additions & 4 deletions)
@@ -110,7 +110,7 @@ class LMContext:
         show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`.
     """

-    def __init__(self, lm, prompt, temp=1.0):
+    def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True):
         """Create a new `LMContext` with a given prompt and temperature.

         Args:
@@ -127,7 +127,8 @@ def __init__(self, lm, prompt, temp=1.0):
         self.model_mask = lm.masks.ALL_TOKENS
         self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
         self.prompt_token_count = len(self.tokens)
-        self.show_prompt = False
+        self.show_prompt = show_prompt
+        self.show_eos = show_eos

     def next_token(self):
         """Distribution over the next token.
@@ -154,9 +155,12 @@ def token_count(self):
         return len(self.tokens) - self.prompt_token_count

     def __str__(self):
-        base = 0 if self.show_prompt else self.prompt_string_length
         full_string = self.lm.tokenizer.decode(self.tokens)
-        return full_string[base:]
+        if not self.show_prompt:
+            full_string = full_string[self.prompt_string_length:]
+        if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token):
+            full_string = full_string[:-len(self.lm.tokenizer.eos_token)]
+        return full_string

     def __deepcopy__(self, memo):
         cpy = type(self).__new__(type(self))
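For reference, a minimal usage sketch of the new flags. Only the `LMContext` constructor signature is confirmed by this commit; the `CachedCausalLM.from_pretrained` loader and the model name are assumptions about the surrounding hfppl API.

# Hypothetical usage sketch for the new show_prompt / show_eos flags.
# CachedCausalLM.from_pretrained is assumed from hfppl's broader API;
# only the LMContext constructor signature comes from this commit.
from hfppl import CachedCausalLM, LMContext

lm = CachedCausalLM.from_pretrained("gpt2")  # assumed loader and model

# Defaults are unchanged: the prompt is hidden and a generated EOS is kept.
ctx = LMContext(lm, "The weather today is")

# New: include the prompt in str(ctx) and trim a trailing EOS token.
verbose_ctx = LMContext(
    lm, "The weather today is", temp=0.7, show_prompt=True, show_eos=False
)

# Once tokens have been sampled into each context, __str__ renders
# verbose_ctx as prompt + generation (trailing EOS trimmed) and ctx as
# the generation alone (EOS left in place).
print(str(verbose_ctx))
print(str(ctx))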
