Skip to content

Commit

Permalink
Make FSM return allowed tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 6, 2023
1 parent 05bd997 commit 858a9cb
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 38 deletions.
28 changes: 9 additions & 19 deletions outlines/fsm/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class FSM(Protocol):
def forbidden_token_ids(self, state: FSMState) -> List[int]:
def allowed_token_ids(self, state: FSMState) -> List[int]:
...

def next_state(self, state: FSMState, token_id: int) -> FSMState:
Expand Down Expand Up @@ -42,8 +42,8 @@ def __init__(
self.vocabulary = tokenizer.vocabulary.values()
self.final_states = {1}

def forbidden_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of forbidden tokens for the next step.
def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
When in the initial state we allow every token to be generated.
In the final state the only allowed token is `stop_token_id`.
Expand All @@ -59,13 +59,9 @@ def forbidden_token_ids(self, state: FSMState) -> List[int]:
"""
if state == 0:
return []
return list(self.vocabulary)
else:
return [
token_id
for token_id in self.vocabulary
if token_id != self.stop_token_id
]
return [self.stop_token_id]

def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
Expand Down Expand Up @@ -137,8 +133,8 @@ def __init__(
self.vocabulary = tokenizer.vocabulary.values()
self.end_token_id = tokenizer.eos_token_id

def forbidden_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of forbidden tokens for the next step.
def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
The initialization of the FSM builds an index which maps FSM states to a
map from authorized tokens to the state in which the FSM needs to move
Expand All @@ -163,15 +159,9 @@ def forbidden_token_ids(self, state: FSMState) -> List[int]:
next_tokens_to_end_states = self.states_to_token_maps.get(state)

if next_tokens_to_end_states is None:
authorized_tokens = [self.end_token_id]
return [self.end_token_id]
else:
authorized_tokens = list(next_tokens_to_end_states.keys())

forbidden_tokens = [
token for token in self.vocabulary if token not in authorized_tokens
]

return list(forbidden_tokens)
return list(next_tokens_to_end_states.keys())

def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
Expand Down
12 changes: 7 additions & 5 deletions outlines/generate/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def sequence_generator(
"""
token_ids, attention_masks, kv_cache = init_state
while True:
logits_masks = get_logits_masks(fsm, fsm_states)
logits_masks = get_allowed_tokens(fsm, fsm_states)

next_token_ids, kv_cache, logits = token_generator(
token_ids,
Expand Down Expand Up @@ -171,7 +171,7 @@ def get_next_fsm_states(
]


def get_logits_masks(fsm: "FSM", fsm_states: List[FSMState]) -> torch.Tensor:
def get_allowed_tokens(fsm: "FSM", fsm_states: List[FSMState]) -> torch.Tensor:
"""Get the new instructions for each sequence from the finite-state machine.
Parameters
Expand All @@ -183,10 +183,10 @@ def get_logits_masks(fsm: "FSM", fsm_states: List[FSMState]) -> torch.Tensor:
Returns
-------
A nested list that contains the ids of the logits to bias.
A nested list that contains the ids of the logits to keep.
"""
return [fsm.forbidden_token_ids(state) for state in fsm_states]
return [fsm.allowed_token_ids(state) for state in fsm_states]


def is_generation_finished(fsm: "FSM", fsm_states: List[FSMState]) -> bool:
Expand Down Expand Up @@ -306,5 +306,7 @@ def bias_logits(
"""
for i, ids in enumerate(ids_to_mask):
logits[i, ids] = -math.inf
mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
mask[ids] = 0
logits[i] = logits[i] + mask
return logits
6 changes: 3 additions & 3 deletions tests/fsm/test_fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class MockTokenizer:

fsm = StopAtTokenFSM(MockTokenizer(), 2)

assert fsm.forbidden_token_ids(0) == []
assert fsm.forbidden_token_ids(1) == [1]
assert fsm.allowed_token_ids(0) == [1, 2]
assert fsm.allowed_token_ids(1) == [2]
assert fsm.next_state(0, 2) == 1
assert fsm.next_state(0, 1) == 0
assert fsm.is_final_state(0) is False
Expand Down Expand Up @@ -46,7 +46,7 @@ def convert_token_to_string(self, token):
fsm = RegexFSM(regex_str, tokenizer)

assert fsm.states_to_token_maps == {0: {1: 1}}
assert fsm.forbidden_token_ids(state=0) == [2, 3]
assert fsm.allowed_token_ids(state=0) == [1]
assert fsm.next_state(state=0, token_id=1) == 1
assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1

Expand Down
22 changes: 11 additions & 11 deletions tests/generate/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from outlines.generate.generator import (
bias_logits,
expand_attention_masks,
get_logits_masks,
get_allowed_tokens,
get_next_fsm_states,
init_generator_state,
is_generation_finished,
Expand All @@ -24,7 +24,7 @@ class MockFSM:
def next_state(self, state, next_token_ids):
return 0

def forbidden_token_ids(self, _):
def allowed_token_ids(self, _):
return []

def is_final_state(self, _):
Expand Down Expand Up @@ -73,7 +73,7 @@ class MockFSM:
def next_state(self, state, next_token_ids):
return 0

def forbidden_token_ids(self, _):
def allowed_token_ids(self, _):
return []

def is_final_state(self, _):
Expand Down Expand Up @@ -108,7 +108,7 @@ class MockFSM:
def next_state(self, state, next_token_ids):
return 0

def forbidden_token_ids(self, _):
def allowed_token_ids(self, _):
return []

def is_final_state(self, _):
Expand Down Expand Up @@ -151,7 +151,7 @@ class MockFSM:
def next_state(self, state, next_token_ids):
return FSMState(state + 1)

def forbidden_token_ids(self, _):
def allowed_token_ids(self, _):
return []

def is_final_state(self, state):
Expand Down Expand Up @@ -201,7 +201,7 @@ class MockFSM:
def next_state(self, state, next_token_ids):
return 0

def forbidden_token_ids(self, _):
def allowed_token_ids(self, _):
return []

def is_final_state(self, _):
Expand Down Expand Up @@ -254,7 +254,7 @@ class MockFSM:
def next_state(self, state, next_token_ids):
return FSMState(state + 1)

def forbidden_token_ids(self, _):
def allowed_token_ids(self, _):
return []

def is_final_state(self, state):
Expand Down Expand Up @@ -398,15 +398,15 @@ def next_state(self, state, next_token_ids):
assert result == [0, 0]


def test_get_forbidden_token_idss():
def test_get_allowed_token_idss():
class MockFSM:
def forbidden_token_ids(self, _):
def allowed_token_ids(self, _):
return [1, 2, 3, 4]

result = get_logits_masks(MockFSM(), [0])
result = get_allowed_tokens(MockFSM(), [0])
assert result == [[1, 2, 3, 4]]

result = get_logits_masks(MockFSM(), [0, 1])
result = get_allowed_tokens(MockFSM(), [0, 1])
assert result == [[1, 2, 3, 4], [1, 2, 3, 4]]


Expand Down

0 comments on commit 858a9cb

Please sign in to comment.