From 1bc9cae9c37ca12ddcf0c68faf806fe9efec6dcc Mon Sep 17 00:00:00 2001
From: venondev
Date: Fri, 15 Dec 2023 16:19:29 +0100
Subject: [PATCH] Fix prompt removal logic and update test case

---
 outlines/generate/api.py         | 15 ++++++++++-----
 tests/generate/test_generator.py | 11 ++++++-----
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 7130bf709..ec6ee951a 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -60,8 +60,6 @@ def __call__(
         if isinstance(prompts, str):
             prompts = [prompts]
 
-        prompt_lengths = [len(prompt) for prompt in prompts]
-
         if rng is None:
             rng = torch.Generator(device=self.device)
             rng.seed()
@@ -82,10 +80,17 @@ def __call__(
             except StopIteration:
                 break
 
-        sequences = self.tokenizer.decode(last_state.token_ids)
-        generated = [
-            sequence[length:] for sequence, length in zip(sequences, prompt_lengths)
+        # Get the number of tokens in the prompts
+        prompt_token_ids = init_state[0]
+        prompt_lengths = [len(prompt_token_ids[i]) for i in range(len(prompts))]
+
+        # Remove the prompts from the generated sequences
+        token_ids = [
+            cur_token_ids[length:]
+            for cur_token_ids, length in zip(last_state.token_ids, prompt_lengths)
         ]
+
+        generated = self.tokenizer.decode(token_ids)
         formatted = [self.format_sequence(sequence) for sequence in generated]
 
         return formatted if len(formatted) > 1 else formatted[0]
diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py
index 1e313ff2e..92a30b067 100644
--- a/tests/generate/test_generator.py
+++ b/tests/generate/test_generator.py
@@ -22,27 +22,28 @@ def test_sequence_generator_class():
     class MockFSM:
         def next_state(self, state, next_token_ids):
-            return 0
+            return 4
 
         def allowed_token_ids(self, _):
-            return []
+            return [4]
 
         def is_final_state(self, _):
             return True
 
     class MockTokenizer:
         def encode(self, _):
+            # Input: "test"
             return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
 
-        def decode(self, _):
-            return ["testx"]
+        def decode(self, tokens):
+            return ["testx"[i] for i in tokens]
 
     class MockModel:
         def __init__(self):
             self.tokenizer = MockTokenizer()
 
         def __call__(*_):
-            return torch.tensor([[0, 1, 2, 3]], dtype=torch.float), None
+            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None
 
     def sampler(biased_logits, *_):
         return torch.argmax(biased_logits, keepdims=True)
 
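
Reviewer note (not part of the patch): the old code sliced the *decoded
string* by len(prompt), i.e. by character count, while the new code slices
the *token ids* (taken from init_state[0]) by the number of prompt tokens
before decoding. The two only agree when decode() reproduces the prompt
byte-for-byte, which tokenizers do not guarantee (special tokens, whitespace
normalization). A minimal, self-contained sketch of the failure mode; the
ToyTokenizer below is illustrative only and not part of the outlines codebase:

    # A toy tokenizer whose decode() includes a BOS marker, so the
    # decoded text is longer than the original prompt string.
    class ToyTokenizer:
        vocab = {"<s>": 0, "Hello": 1, " world": 2, "!": 3}
        ids_to_str = {i: s for s, i in vocab.items()}

        def encode(self, text):
            ids = [0]  # BOS
            rest = text
            while rest:
                for piece, i in self.vocab.items():
                    if rest.startswith(piece):
                        ids.append(i)
                        rest = rest[len(piece):]
                        break
                else:
                    raise ValueError(f"cannot tokenize: {rest!r}")
            return ids

        def decode(self, ids):
            return "".join(self.ids_to_str[i] for i in ids)

    tok = ToyTokenizer()
    prompt = "Hello"
    prompt_ids = tok.encode(prompt)       # [0, 1], decodes to "<s>Hello"
    sequence_ids = prompt_ids + [2, 3]    # the model appended " world!"

    # Old behaviour: a character-based slice leaves prompt residue,
    # because the BOS marker shifts the decoded string.
    assert tok.decode(sequence_ids)[len(prompt):] == "llo world!"

    # New behaviour: slice in token space, then decode.
    assert tok.decode(sequence_ids[len(prompt_ids):]) == " world!"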