Skip to content

Commit

Permalink
allow generation of %ignore terminals (#570)
Browse files Browse the repository at this point in the history
Fixes #565

---------

Co-authored-by: Andrew Lapp <[email protected]>
  • Loading branch information
lapp0 and Andrew Lapp authored Feb 6, 2024
1 parent 407d76e commit 0b9528e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
3 changes: 3 additions & 0 deletions outlines/fsm/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:

interactive = self.parser.parse_interactive(self.generation)
interactive.exhaust_lexer()

options = {self.terminal_regexps[x] for x in interactive.accepts()}
# add %ignore terminals
options |= {self.terminal_regexps[x] for x in self.parser.lexer_conf.ignore}

if self.terminal_regexps["$END"] in options:
options.remove(self.terminal_regexps["$END"])
Expand Down
62 changes: 62 additions & 0 deletions tests/fsm/test_fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,68 @@ def decode(self, token_ids):
assert set(fsm.allowed_token_ids(state=state)) == {3}


def test_cfg_ignore_directive():
class MockTokenizer:
vocabulary = {"a": 1, " ": 2, "eos": 3}
special_tokens = {"eos"}
eos_token = "eos"
eos_token_id = 3

def convert_token_to_string(self, token):
return token

@property
def inverse_vocabulary(self):
return {v: k for k, v in self.vocabulary.items()}

def decode(self, token_ids):
return [self.inverse_vocabulary[t] for t in token_ids]

cfg_str = """
start: LETTER+
LETTER: "a"
WS: " "
%ignore WS
"""
tokenizer = MockTokenizer()
fsm = CFGFSM(cfg_str, tokenizer)

state = 0

assert set(fsm.allowed_token_ids(state=0)) == {1, 2}
state = fsm.next_state(state=0, token_id=2)
assert fsm.generation == " "
assert not fsm.is_final_state(state)

assert set(fsm.allowed_token_ids(state=0)) == {1, 2}
state = fsm.next_state(state=0, token_id=1)
assert fsm.generation == " a"
assert not fsm.is_final_state(state)

assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
state = fsm.next_state(state=state, token_id=2)
assert fsm.generation == " a "
assert not fsm.is_final_state(state)

assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
state = fsm.next_state(state=state, token_id=2)
assert fsm.generation == " a "
assert not fsm.is_final_state(state)

assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
state = fsm.next_state(state=state, token_id=1)
assert fsm.generation == " a a"
assert not fsm.is_final_state(state)

assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
state = fsm.next_state(state=state, token_id=3)
assert fsm.generation == " a a"
assert fsm.is_final_state(state)

# once eos generated, can only terminate
assert set(fsm.allowed_token_ids(state=state)) == {3}


def test_cfg_multitoken_terminal():
class MockTokenizer:
vocabulary = {"a": 1, "b": 2, "eos": 3}
Expand Down

0 comments on commit 0b9528e

Please sign in to comment.