diff --git a/hfppl/distributions/lmcontext.py b/hfppl/distributions/lmcontext.py
index 30f4d0b..16b8247 100644
--- a/hfppl/distributions/lmcontext.py
+++ b/hfppl/distributions/lmcontext.py
@@ -110,7 +110,7 @@ class LMContext:
         show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`.
     """
 
-    def __init__(self, lm, prompt, temp=1.0):
+    def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True):
         """Create a new `LMContext` with a given prompt and temperature.
 
         Args:
@@ -127,7 +127,8 @@ def __init__(self, lm, prompt, temp=1.0):
         self.model_mask = lm.masks.ALL_TOKENS
         self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
         self.prompt_token_count = len(self.tokens)
-        self.show_prompt = False
+        self.show_prompt = show_prompt
+        self.show_eos = show_eos
 
     def next_token(self):
         """Distribution over the next token.
@@ -154,9 +155,12 @@ def token_count(self):
         return len(self.tokens) - self.prompt_token_count
 
     def __str__(self):
-        base = 0 if self.show_prompt else self.prompt_string_length
         full_string = self.lm.tokenizer.decode(self.tokens)
-        return full_string[base:]
+        if not self.show_prompt:
+            full_string = full_string[self.prompt_string_length:]
+        if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token):
+            full_string = full_string[:-len(self.lm.tokenizer.eos_token)]
+        return full_string
 
     def __deepcopy__(self, memo):
         cpy = type(self).__new__(type(self))
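
For illustration, a minimal usage sketch of the two new constructor flags. Only the `LMContext(...)` signature and the `__str__` behavior come from this diff; the import paths, model checkpoint, and prompt below are assumptions and may need adjusting to the installed hfppl version.

# Assumed imports: hfppl re-exports CachedCausalLM and LMContext at the top level in recent versions.
from hfppl import CachedCausalLM, LMContext

# Hypothetical model handle; any hfppl.llms.CachedCausalLM instance works here.
lm = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# New keyword arguments introduced by this diff:
#   show_prompt=True  -> str(ctx) keeps the initial prompt text
#   show_eos=False    -> str(ctx) strips a trailing eos_token, if present
ctx = LMContext(lm, "Write a haiku about rivers:", temp=0.7, show_prompt=False, show_eos=False)

# After generation steps have appended tokens (possibly ending in EOS),
# the string representation reflects both flags: prompt omitted, EOS stripped.
print(str(ctx))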