Fix custom sampler setup and add a test
brandonwillard committed Oct 12, 2023
1 parent fb465db commit 7083003
Showing 2 changed files with 32 additions and 0 deletions.
2 changes: 2 additions & 0 deletions outlines/text/generate/sequence.py
@@ -43,6 +43,8 @@ def __init__(
             from outlines.text.generate.sample import multinomial
 
             self.sampler = multinomial
+        else:
+            self.sampler = sampler
 
     def create_proposal(
         self, generated_token_ids: torch.LongTensor, logits: torch.DoubleTensor
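The change above makes `Sequence.__init__` keep a user-supplied sampler: judging from the diff, `self.sampler` was previously only assigned inside the `sampler is None` branch, so a custom sampler passed to the constructor appears to have been dropped. A minimal sketch of the resulting selection logic (the constructor signature is abbreviated here and assumed, not copied from the repository):

    def __init__(self, model, max_tokens=None, sampler=None):
        ...
        if sampler is None:
            # No sampler given: fall back to the library's multinomial sampler.
            from outlines.text.generate.sample import multinomial

            self.sampler = multinomial
        else:
            # Fix: store the caller-provided sampling function.
            self.sampler = sampler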
30 changes: 30 additions & 0 deletions tests/text/generate/test_integration_transfomers.py
@@ -268,3 +268,33 @@ def test_transformers_reduced_vocabulary_caching():
     vocab2 = reduced_vocabulary(tokenizer2)
 
     assert vocab2 is vocab
+
+
+def test_custom_sampler():
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+
+    model = models.transformers(model_name)
+
+    seen = False
+    target_token_ids = model.tokenizer.encode(["c"])[0]
+
+    def biased_sampler(
+        logits: torch.DoubleTensor, samples: int, *_
+    ) -> torch.DoubleTensor:
+        nonlocal seen
+
+        if not seen:
+            seen = True
+            return target_token_ids
+        else:
+            return torch.tensor([[model.tokenizer.eos_token_id]])
+
+    generator = generate.choice(model, ["a", "b", "c"], sampler=biased_sampler)
+    sequence = generator(
+        """What is 1+1?
+        a. 3
+        b. 4
+        c. 2"""
+    )
+
+    assert sequence == "c"
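The test drives generation with a sampler that first returns the token ids for "c" and then the EOS token, so `generate.choice` must produce "c". As a usage sketch outside a test (the greedy sampler below is hypothetical, and the import paths and model name are assumptions inferred from this test, not prescribed by the commit), any callable with the same `(logits, samples, *extra)` signature can be passed:

    import torch

    import outlines.models as models
    import outlines.text.generate as generate


    def greedy_sampler(
        logits: torch.DoubleTensor, samples: int, *_
    ) -> torch.LongTensor:
        # Always pick the highest-scoring token instead of sampling.
        return torch.argmax(logits, dim=-1, keepdim=True)


    model = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")
    generator = generate.choice(model, ["a", "b", "c"], sampler=greedy_sampler)
    answer = generator("Pick a, b, or c:")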
