Merge pull request #16 from probcomp/gg/sample_word_2
Additional submodels and masks
alex-lew authored Jul 18, 2024
2 parents 8e86427 + ebaed8d commit cd2b3ea
Showing 3 changed files with 130 additions and 6 deletions.
hfppl/chunks.py (86 additions, 0 deletions)
@@ -49,3 +49,89 @@ async def sample_word(self, context, max_tokens=5, allow_punctuation=True):
punctuation = context.lm.vocab[punctuation_token.token_id]

return word, punctuation


@submodel
async def sample_word_2(
self,
context,
max_chars: int = None,
allow_mid_punctuation: bool = True,
allow_end_punctuation: bool = True,
):
"""Sample a word from the `LMContext` object `context`.
Unlike sample_word() above, this method allows for character-level control over the length of the word.
It also allows for control over the presence of punctuation in the middle and at the end of the word.
Args:
max_chars (int): Maximum number of characters in the word. If None, the model will sample a word of any length.
allow_mid_punctuation (bool): If True, the model may sample punctuation in the middle of the word.
allow_end_punctuation (bool): If True, the model may sample punctuation at the end of the word.
Returns:
Tuple[str, str]: The sampled word and punctuation
"""

# This approach sometimes breaks with max_chars = 1
if max_chars is not None:
assert max_chars > 1

last_token = context.lm.vocab[context.tokens[-1]] if len(context.tokens) > 0 else ""
last_character = last_token[-1] if len(last_token) > 0 else ""
needs_space = last_character not in string.whitespace and last_character not in [
"-",
"'",
'"',
]
if needs_space:
starts_word_mask = context.lm.masks.STARTS_NEW_WORD
else:
starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD

    # Force the model to begin the word with or without a leading space, as appropriate
await self.observe(context.mask_dist(starts_word_mask), True)

word = ""
while True:
# Force model to sample a token with an appropriate number of characters
if max_chars is not None:
await self.observe(
context.mask_dist(
context.lm.masks.MAX_TOKEN_LENGTH[max_chars - len(word.strip())]
),
True,
)

token = await self.sample(context.next_token())
word += context.lm.vocab[token.token_id]

        # If the character budget is used up, require the word to end here and stop
if max_chars is not None and len(word.strip()) >= max_chars:
await self.observe(
context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False
)
break

# If the model wants to end the word, break
if not (
await self.sample(
context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD)
)
):
break

# Sample punctuation, if desired
punctuation = ""

mask = set()
if allow_mid_punctuation:
mask = mask | context.lm.masks.MID_PUNCTUATION
if allow_end_punctuation:
mask = mask | context.lm.masks.END_PUNCTUATION

if mask and await self.sample(context.mask_dist(mask)):
punctuation_token = await self.sample(context.next_token())
punctuation = context.lm.vocab[punctuation_token.token_id]

return word, punctuation
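
A minimal usage sketch of the new submodel (not part of this commit; the model class, prompt, and stopping rule below are hypothetical, and it assumes hfppl's convention of invoking @submodel-decorated coroutines via `await self.call(...)`):

from hfppl import Model, LMContext
from hfppl.chunks import sample_word_2

class ShortWordModel(Model):
    def __init__(self, lm, prompt):
        super().__init__()
        self.context = LMContext(lm, prompt)

    async def step(self):
        # Cap each word at 6 characters; allow punctuation only at the end.
        word, punct = await self.call(
            sample_word_2(
                self.context,
                max_chars=6,
                allow_mid_punctuation=False,
                allow_end_punctuation=True,
            )
        )
        # Finish this particle once a sentence-final mark appears.
        if punct in (".", "!", "?"):
            self.finish()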
hfppl/distributions/lmcontext.py (13 additions, 4 deletions)
@@ -110,7 +110,7 @@ class LMContext:
show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`.
"""

-    def __init__(self, lm, prompt, temp=1.0):
+    def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True):
        """Create a new `LMContext` with a given prompt and temperature.

        Args:
@@ -126,7 +126,9 @@ def __init__(self, lm, prompt, temp=1.0):
        self.temp = temp
        self.model_mask = lm.masks.ALL_TOKENS
        self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
-        self.show_prompt = False
+        self.prompt_token_count = len(self.tokens)
+        self.show_prompt = show_prompt
+        self.show_eos = show_eos

def next_token(self):
"""Distribution over the next token.
@@ -148,10 +150,17 @@ def mask_dist(self, mask):
"""
return LMTokenMask(self, mask)

+    @property
+    def token_count(self):
+        return len(self.tokens) - self.prompt_token_count
+
    def __str__(self):
-        base = 0 if self.show_prompt else self.prompt_string_length
        full_string = self.lm.tokenizer.decode(self.tokens)
-        return full_string[base:]
+        if not self.show_prompt:
+            full_string = full_string[self.prompt_string_length:]
+        if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token):
+            full_string = full_string[:-len(self.lm.tokenizer.eos_token)]
+        return full_string

def __deepcopy__(self, memo):
cpy = type(self).__new__(type(self))
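
A quick sketch of the new constructor flags and the `token_count` property (the `lm` handle and prompt string below are placeholders):

# show_prompt controls whether str(context) includes the prompt;
# show_eos controls whether a trailing EOS token is kept in the string.
ctx = LMContext(lm, "The weather today is", temp=0.7, show_prompt=True, show_eos=False)
print(str(ctx))         # includes the prompt, hides any trailing EOS
print(ctx.token_count)  # 0 until tokens are sampled beyond the prompt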
hfppl/llms.py (31 additions, 2 deletions)
@@ -22,8 +22,37 @@ def __init__(self, lm):
for (i, v) in enumerate(lm.vocab)
if all(c in "'" or c.isalpha() for c in v)
)
-        self.PUNCTUATION = set(i for (i, v) in enumerate(lm.vocab) if v in ',:;.!?"-')
-        self.END_SENTENCE_PUNCT = set(i for (i, v) in enumerate(lm.vocab) if v in ".!?")
+        self.MID_PUNCTUATION = set(
+            i for (i, v) in enumerate(lm.vocab) if v in (",", ":", ";", "-", '"')
+        )
+        self.END_PUNCTUATION = set(
+            i for (i, v) in enumerate(lm.vocab) if v in (".", "!", "?")
+        )
+        self.PUNCTUATION = self.MID_PUNCTUATION | self.END_PUNCTUATION
+        self.CONTAINS_WHITESPACE = set(
+            i
+            for (i, v) in enumerate(lm.vocab)
+            if any(c in string.whitespace for c in v)
+        )
+
+        self.MAX_TOKEN_LENGTH = self.precompute_token_length_masks(lm)
+
+    def precompute_token_length_masks(self, lm) -> Dict[int, Set[int]]:
+        """Precompute masks for tokens of different lengths.
+        Each mask is a set of token ids that are of the given length or shorter."""
+        max_token_length = max([len(t) for t in lm.vocab])
+
+        masks = defaultdict(lambda: self.ALL_TOKENS)
+        masks[0] = set([lm.tokenizer.eos_token_id])
+        for token_length in range(1, max_token_length + 1):
+            masks[token_length] = set(
+                i
+                for (i, v) in enumerate(lm.vocab)
+                if len(v) <= token_length and i != lm.tokenizer.eos_token_id
+            )
+
+        return masks


class TokenSequence:
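
A sketch of how the new masks compose (assumes `lm` is a loaded hfppl language model; token ids are model-dependent):

# MAX_TOKEN_LENGTH[k] holds every non-EOS token of at most k characters;
# keys beyond the longest vocabulary entry fall back to ALL_TOKENS.
short_continuations = lm.masks.MAX_TOKEN_LENGTH[3] & lm.masks.CONTINUES_CURRENT_WORD

# Length 0 admits only the EOS token.
assert lm.masks.MAX_TOKEN_LENGTH[0] == {lm.tokenizer.eos_token_id}

# MID_PUNCTUATION and END_PUNCTUATION together replace the old PUNCTUATION set.
assert lm.masks.PUNCTUATION == lm.masks.MID_PUNCTUATION | lm.masks.END_PUNCTUATION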
