diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py index bc1c51508..0cd91df50 100644 --- a/outlines/fsm/fsm.py +++ b/outlines/fsm/fsm.py @@ -11,7 +11,7 @@ class FSM(Protocol): - def forbidden_token_ids(self, state: FSMState) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: ... def next_state(self, state: FSMState, token_id: int) -> FSMState: @@ -42,8 +42,8 @@ def __init__( self.vocabulary = tokenizer.vocabulary.values() self.final_states = {1} - def forbidden_token_ids(self, state: FSMState) -> List[int]: - """Generate a list of forbidden tokens for the next step. + def allowed_token_ids(self, state: FSMState) -> List[int]: + """Generate a list of allowed tokens for the next step. When in the initial state we allow every token to be generated. In the final state the only allowed token is `stop_token_id`. @@ -59,13 +59,9 @@ def forbidden_token_ids(self, state: FSMState) -> List[int]: """ if state == 0: - return [] + return list(self.vocabulary) else: - return [ - token_id - for token_id in self.vocabulary - if token_id != self.stop_token_id - ] + return [self.stop_token_id] def next_state(self, state: FSMState, token_id: int) -> FSMState: """Update the state of the FSM. @@ -137,8 +133,8 @@ def __init__( self.vocabulary = tokenizer.vocabulary.values() self.end_token_id = tokenizer.eos_token_id - def forbidden_token_ids(self, state: FSMState) -> List[int]: - """Generate a list of forbidden tokens for the next step. + def allowed_token_ids(self, state: FSMState) -> List[int]: + """Generate a list of allowed tokens for the next step. 
The initialization of the FSM builds an index which maps FSM states to a map from authorized tokens to the state in which the FSM needs to move @@ -163,15 +159,9 @@ def forbidden_token_ids(self, state: FSMState) -> List[int]: next_tokens_to_end_states = self.states_to_token_maps.get(state) if next_tokens_to_end_states is None: - authorized_tokens = [self.end_token_id] + return [self.end_token_id] else: - authorized_tokens = list(next_tokens_to_end_states.keys()) - - forbidden_tokens = [ - token for token in self.vocabulary if token not in authorized_tokens - ] - - return list(forbidden_tokens) + return list(next_tokens_to_end_states.keys()) def next_state(self, state: FSMState, token_id: int) -> FSMState: """Update the state of the FSM. diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index 6dd57c139..c5e9fb6a3 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -78,7 +78,7 @@ def sequence_generator( """ token_ids, attention_masks, kv_cache = init_state while True: - logits_masks = get_logits_masks(fsm, fsm_states) + logits_masks = get_allowed_tokens(fsm, fsm_states) next_token_ids, kv_cache, logits = token_generator( token_ids, @@ -171,7 +171,7 @@ def get_next_fsm_states( ] -def get_logits_masks(fsm: "FSM", fsm_states: List[FSMState]) -> torch.Tensor: +def get_allowed_tokens(fsm: "FSM", fsm_states: List[FSMState]) -> torch.Tensor: """Get the new instructions for each sequence from the finite-state machine. Parameters @@ -183,10 +183,10 @@ def get_logits_masks(fsm: "FSM", fsm_states: List[FSMState]) -> torch.Tensor: Returns ------- - A nested list that contains the ids of the logits to bias. + A nested list that contains the ids of the logits to keep. 
""" - return [fsm.forbidden_token_ids(state) for state in fsm_states] + return [fsm.allowed_token_ids(state) for state in fsm_states] def is_generation_finished(fsm: "FSM", fsm_states: List[FSMState]) -> bool: @@ -306,5 +306,7 @@ def bias_logits( """ for i, ids in enumerate(ids_to_mask): - logits[i, ids] = -math.inf + mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device) + mask[ids] = 0 + logits[i] = logits[i] + mask return logits diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py index f2db973bc..19b11daa1 100644 --- a/tests/fsm/test_fsm.py +++ b/tests/fsm/test_fsm.py @@ -10,8 +10,8 @@ class MockTokenizer: fsm = StopAtTokenFSM(MockTokenizer(), 2) - assert fsm.forbidden_token_ids(0) == [] - assert fsm.forbidden_token_ids(1) == [1] + assert fsm.allowed_token_ids(0) == [1, 2] + assert fsm.allowed_token_ids(1) == [2] assert fsm.next_state(0, 2) == 1 assert fsm.next_state(0, 1) == 0 assert fsm.is_final_state(0) is False @@ -46,7 +46,7 @@ def convert_token_to_string(self, token): fsm = RegexFSM(regex_str, tokenizer) assert fsm.states_to_token_maps == {0: {1: 1}} - assert fsm.forbidden_token_ids(state=0) == [2, 3] + assert fsm.allowed_token_ids(state=0) == [1] assert fsm.next_state(state=0, token_id=1) == 1 assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1 diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py index 6d80d2454..a29b1e263 100644 --- a/tests/generate/test_generator.py +++ b/tests/generate/test_generator.py @@ -9,7 +9,7 @@ from outlines.generate.generator import ( bias_logits, expand_attention_masks, - get_logits_masks, + get_allowed_tokens, get_next_fsm_states, init_generator_state, is_generation_finished, @@ -24,7 +24,7 @@ class MockFSM: def next_state(self, state, next_token_ids): return 0 - def forbidden_token_ids(self, _): + def allowed_token_ids(self, _): return [] def is_final_state(self, _): @@ -73,7 +73,7 @@ class MockFSM: def next_state(self, state, next_token_ids): return 0 - 
def forbidden_token_ids(self, _): + def allowed_token_ids(self, _): return [] def is_final_state(self, _): @@ -108,7 +108,7 @@ class MockFSM: def next_state(self, state, next_token_ids): return 0 - def forbidden_token_ids(self, _): + def allowed_token_ids(self, _): return [] def is_final_state(self, _): @@ -151,7 +151,7 @@ class MockFSM: def next_state(self, state, next_token_ids): return FSMState(state + 1) - def forbidden_token_ids(self, _): + def allowed_token_ids(self, _): return [] def is_final_state(self, state): @@ -201,7 +201,7 @@ class MockFSM: def next_state(self, state, next_token_ids): return 0 - def forbidden_token_ids(self, _): + def allowed_token_ids(self, _): return [] def is_final_state(self, _): @@ -254,7 +254,7 @@ class MockFSM: def next_state(self, state, next_token_ids): return FSMState(state + 1) - def forbidden_token_ids(self, _): + def allowed_token_ids(self, _): return [] def is_final_state(self, state): @@ -398,15 +398,15 @@ def next_state(self, state, next_token_ids): assert result == [0, 0] -def test_get_forbidden_token_idss(): +def test_get_allowed_token_ids(): class MockFSM: - def forbidden_token_ids(self, _): + def allowed_token_ids(self, _): return [1, 2, 3, 4] - result = get_logits_masks(MockFSM(), [0]) + result = get_allowed_tokens(MockFSM(), [0]) assert result == [[1, 2, 3, 4]] - result = get_logits_masks(MockFSM(), [0, 1]) + result = get_allowed_tokens(MockFSM(), [0, 1]) assert result == [[1, 2, 3, 4], [1, 2, 3, 4]]