Skip to content

Commit

Permalink
Update 'Getting Started' doc to use same example as README
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-lew committed Jul 12, 2024
1 parent f505208 commit d54a59f
Showing 1 changed file with 12 additions and 21 deletions.
33 changes: 12 additions & 21 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,48 +33,39 @@ To do so, we subclass the [`Model`](hfppl.modeling.Model) class:
```python
# examples/no_e.py

from hfppl import Model, LMContext, TokenCategorical, CachedCausalLM
from hfppl import Model, LMContext, CachedCausalLM

# A LLaMPPL model subclasses the Model class
class MyModel(Model):

# The __init__ method is used to process arguments
# and initialize instance variables.
def __init__(self, lm, prompt, forbidden_letter):

# Always call the superclass's __init__.
super().__init__()

# A stateful context object for the LLM, initialized with the prompt
self.context = LMContext(lm, prompt)

self.eos_token = lm.tokenizer.eos_token_id

# The IDs of all vocabulary tokens that contain the forbidden letter
self.forbidden_tokens = [i for (i, v) in enumerate(lm.vocab)
if forbidden_letter in v]

self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab)
if forbidden_letter in v)
# The step method is used to perform a single 'step' of generation.
# This might be a single token, a single phrase, or any other division.
# Here, we generate one token at a time.
async def step(self):
# Sample a token from the LLM -- automatically extends `self.context`.
# We use `await` so that LLaMPPL can automatically batch language model calls.
token = await self.sample(self.context.next_token(),
proposal=self.proposal())

# Condition on the token not having the forbidden letter
self.condition(token.token_id not in self.forbidden_tokens)
# Condition on the next token *not* being a forbidden token.
await self.observe(self.context.mask_dist(self.forbidden_tokens), False)

# Sample the next token from the LLM -- automatically extends `self.context`.
token = await self.sample(self.context.next_token())

# Check for EOS or end of sentence
if token.token_id == self.context.lm.tokenizer.eos_token_id or str(token) in ['.', '!', '?']:
if token.token_id == self.eos_token or str(token) in ['.', '!', '?']:
# Finish generation
self.finish()

# Helper method to define a custom proposal
def proposal(self):
logits = self.context.next_token_logprobs.copy()
logits[self.forbidden_tokens] = -float('inf')
return TokenCategorical(self.context.lm, logits)

# To improve performance, a hint that `self.forbidden_tokens` is immutable
def immutable_properties(self):
return set(['forbidden_tokens'])
Expand Down

0 comments on commit d54a59f

Please sign in to comment.