Skip to content

Commit

Permalink
fix cfg incomplete edge case; can now allow model to sample whether to extend current terminal vs transition to next
Browse files Browse the repository at this point in the history
  • Loading branch information
benlipkin authored and rlouf committed Jan 24, 2024
1 parent b4ab2f2 commit eb031f0
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 41 deletions.
48 changes: 39 additions & 9 deletions examples/cfg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import outlines.generate as generate
import outlines.models as models

# examples from https://lark-parser.readthedocs.io/en/latest/examples/index.html

nlamb_grammar = """
nlamb_grammar = r"""
start: sentence
sentence: noun verb noun -> simple
Expand All @@ -21,7 +19,7 @@
%ignore WS
"""

calc_grammar = """
calc_grammar = r"""
?start: sum
| NAME "=" sum -> assign_var
Expand All @@ -38,22 +36,54 @@
| NAME -> var
| "(" sum ")"
%import common.CNAME -> NAME
%import common.NUMBER
%import common.LETTER -> NAME
%import common.INT -> NUMBER
%import common.WS_INLINE
%ignore WS_INLINE
"""

dyck_grammar = r"""
start: s
s: /a+/
| "(" s ")"
| "{" s "}"
| "[" s "]"
"""

json_grammar = r"""
?start: value
?value: object
| array
| string
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : string ":" value
inner: /([^"]|\\\")+/ |
string : "\"" inner "\""
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""

model = models.transformers("hf-internal-testing/tiny-random-gpt2")
batch_size = 10
for grammar in [nlamb_grammar, calc_grammar]:
generator = generate.cfg(model, grammar)
for grammar in [nlamb_grammar, calc_grammar, dyck_grammar, json_grammar]:
generator = generate.cfg(model, grammar, max_tokens=model.model.config.n_positions)
sequences = generator([" "] * batch_size)
for seq in sequences:
try:
parse = generator.fsm.parser.parse(seq)
assert parse is not None
print("SUCCESS", seq)
except Exception:
except Exception: # will also fail if goes over max_tokens / context window
print("FAILURE", seq)
33 changes: 25 additions & 8 deletions outlines/fsm/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
self.allow_eos = False
self.regex_fsm: RegexFSM

self.check_last = False
self.proposal_last: List[int] = []
self.regex_fsm_last: RegexFSM

def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
Expand Down Expand Up @@ -243,14 +247,21 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
if self.is_final_state(state):
return [self.tokenizer.eos_token_id]

proposal = []
if self.generation != "":
proposal = self.regex_fsm.allowed_token_ids(state)
if self.check_last:
proposer = self.regex_fsm_last
else:
proposer = self.regex_fsm
proposal += proposer.allowed_token_ids(state)
if self.tokenizer.eos_token_id not in proposal:
return proposal
if set(proposal) != {self.tokenizer.eos_token_id}:
if False: # TODO: THIS NEEDS TO BE SAMPLED
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
return proposal
self.check_last = False
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
if len(proposal) > 0:
self.check_last = True
self.proposal_last = proposal.copy()
self.regex_fsm_last = proposer

interactive = self.parser.parse_interactive(self.generation)
interactive.exhaust_lexer()
Expand All @@ -268,7 +279,7 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
self.reset_state = True

proposal = self.regex_fsm.allowed_token_ids(self.first_state)
proposal += self.regex_fsm.allowed_token_ids(self.first_state)
if self.allow_eos:
self.allow_eos = False
else:
Expand Down Expand Up @@ -296,12 +307,18 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""
if token_id == self.tokenizer.eos_token_id:
return self.final_state

self.generation += self.tokenizer.decode([token_id])[0]

if self.check_last:
if token_id in self.proposal_last:
return self.regex_fsm_last.next_state(state, token_id)
self.check_last = False

if self.reset_state:
self.reset_state = False
state = self.first_state

self.generation += self.tokenizer.decode([token_id])[0]

return self.regex_fsm.next_state(state, token_id)

def copy(self) -> "CFGFSM":
Expand Down
48 changes: 24 additions & 24 deletions tests/fsm/test_fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def decode(self, token_ids):
assert set(fsm.allowed_token_ids(state=state)) == {3}


def test_cfg_multitoken_subexpr():
def test_cfg_multitoken_terminal():
class MockTokenizer:
vocabulary = {"a": 1, "b": 2, "eos": 3}
special_tokens = {"eos"}
Expand Down Expand Up @@ -193,17 +193,12 @@ def decode(self, token_ids):
assert fsm.is_final_state(state)


@pytest.mark.xfail(
strict=True,
reason="Current regex implementation is not complete",
raises=NotImplementedError,
)
def test_cfg_overlapping_subexpr():
def test_cfg_allow_both_extend_and_shift_terminal():
class MockTokenizer:
vocabulary = {"a": 1, "b": 2, "eos": 3}
vocabulary = {"(": 1, ")": 2, "a": 3, "eos": 4}
special_tokens = {"eos"}
eos_token = "eos"
eos_token_id = 3
eos_token_id = 4

def convert_token_to_string(self, token):
return token
Expand All @@ -216,28 +211,33 @@ def decode(self, token_ids):
return [self.inverse_vocabulary[t] for t in token_ids]

cfg_str = """
start: S
S: "a" | "b" | "aa" | "bb"
start: s
s: "(" s ")" | /a+/
"""
tokenizer = MockTokenizer()
fsm = CFGFSM(cfg_str, tokenizer)

assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1, 2}
assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1, 3}
state = fsm.next_state(state=fsm.first_state, token_id=1)
assert fsm.generation == "a"
assert fsm.generation == "("
assert not fsm.is_final_state(state)

# INTENDED LOGIC
# This will fail until we fix TODO raised in https://github.com/outlines-dev/outlines/pull/391
try:
assert set(fsm.allowed_token_ids(state=state)) == {1, 3}
except AssertionError:
raise NotImplementedError("TODO: fix this")
assert set(fsm.allowed_token_ids(state=state)) == {1, 3}
state = fsm.next_state(state=state, token_id=3)
assert fsm.generation == "(a"
assert not fsm.is_final_state(state)

# CURRENT LOGIC
# For now, the FSM can only generate the greedy completion, ending at "a", never "aa"
# This implementation is sound, and always terminates, but is not complete
assert set(fsm.allowed_token_ids(state=state)) == {3}
assert set(fsm.allowed_token_ids(state=state)) == {2, 3}
state = fsm.next_state(state=state, token_id=3)
assert fsm.generation == "a"
assert fsm.generation == "(aa"
assert not fsm.is_final_state(state)

assert set(fsm.allowed_token_ids(state=state)) == {2, 3}
state = fsm.next_state(state=state, token_id=2)
assert fsm.generation == "(aa)"
assert not fsm.is_final_state(state)

assert set(fsm.allowed_token_ids(state=state)) == {4}
state = fsm.next_state(state=state, token_id=4)
assert fsm.generation == "(aa)"
assert fsm.is_final_state(state)

0 comments on commit eb031f0

Please sign in to comment.