
Always allow special tokens
shubhamugare committed Nov 2, 2023
1 parent 4801e43 commit 0c38b7e
Showing 4 changed files with 20 additions and 11 deletions.
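In short: TerminalsNFA gains a special_token_ids parameter, and every token mask it builds now starts from a default mask in which those ids are already set, so special tokens such as the tokenizer's EOS token can never be masked out. A minimal standalone sketch of that idea, mirroring the _get_default_mask helper added below:

import torch

# Sketch of the new default mask: all False except the special tokens,
# so later unions can only add tokens and never drop these.
def default_mask(vocab_size, special_token_ids):
    mask = torch.zeros(vocab_size, dtype=torch.bool)
    for token_id in special_token_ids:
        mask[token_id] = True
    return mask

print(default_mask(8, [2]))
# tensor([False, False,  True, False, False, False, False, False])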
2 changes: 1 addition & 1 deletion llm_cfg/common.py
@@ -43,7 +43,7 @@ def load_nfa(tokenizer=None, inc_parser=None, use_cache=True):

    exceptions = {'COMMENT': '#.*|\'\'\'.*?\'\'\'|""".*?"""/is', '_NL': '(\r?\n[\t ]*)+', 'LONG_STRING': '\'\'\'.*?\'\'\'|""".*?"""/is', 'STRING': '[ubf]?r?(".*?"|\'.*?\')'}
    # , '_TAB': '\t+'
-    nfa = TerminalsNFA(inc_parser.parser.terminals, vocab, exceptions=exceptions)
+    nfa = TerminalsNFA(inc_parser.parser.terminals, vocab, exceptions=exceptions, special_token_ids=[tokenizer.eos_token_id])
    print(f'Time taken for creating NFA:', time.time() - start_time, flush=True)

    pickle.dump(nfa, open(NFA_LOC, 'wb'))
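The only call-site change is forwarding the tokenizer's EOS id. For context, a minimal sketch of where that id comes from, assuming a Hugging Face tokenizer (gpt2 is used purely for illustration; any model with an EOS token works the same way):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# This is the id forwarded as special_token_ids=[tokenizer.eos_token_id] above.
print(tokenizer.eos_token, tokenizer.eos_token_id)  # <|endoftext|> 50256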
2 changes: 1 addition & 1 deletion llm_cfg/infer.py
@@ -7,7 +7,7 @@
    PreTrainedTokenizer,
    LogitsProcessorList
)
-from core import filter_code, run_eval, fix_indents
+from core import run_eval
import os
import torch
import argparse
20 changes: 12 additions & 8 deletions llm_cfg/terminals_nfa.py
@@ -21,11 +21,12 @@ class TerminalsNFA:
    """
    We build an NFA that consists of DFAs for each terminal. We simulate the NFA by consuming the input string with each terminal's DFA.
    """
-    def __init__(self, terminals: list, vocab, exceptions={}):
+    def __init__(self, terminals: list, vocab, exceptions={}, special_token_ids=[]):
        self._terminals_to_dfa = {}
        self._vocab = vocab
        self.anything_else = interegular.fsm.anything_else # This is the special character used by the DFAs
        self.exceptions = []
+        self.special_token_ids = special_token_ids

        for terminal in terminals:
            if terminal.name in exceptions:
@@ -42,13 +43,18 @@ def __init__(self, terminals: list, vocab, exceptions={}, special_token_ids=[]):

        self._convert_lookup_from_list_to_mask() # convert to boolean tensor mask. This is useful for fast union operations

+    def _get_default_mask(self):
+        mask = torch.zeros(len(self._vocab), dtype=torch.bool)
+        for token_id in self.special_token_ids:
+            mask[token_id] = True
+        return mask

    def _store_overapproximate_tokens(self, terminals: list[str], vocab):
        for cur_terminal in terminals:
            for dfa_state in self._terminals_to_dfa[cur_terminal].states:

                # Initialize the overapproximate tokens for each dfa state
-                self._dfa_state_to_tokens[(cur_terminal, dfa_state)] = torch.zeros(len(self._vocab), dtype=torch.bool)
+                self._dfa_state_to_tokens[(cur_terminal, dfa_state)] = self._get_default_mask()

                # self._dfa_state_and_next_terminal_to_tokens[(dfa_state, next_terminal)] = []
                for token_idx, token in enumerate(vocab):
@@ -139,8 +145,6 @@ def _nfa_state(self, input_str):
        """
        nfa_state = []
        for (terminal, dfa) in self._terminals_to_dfa.items():
-            # print(terminal)
-            # print(terminal, input_str)
            dfa_state = self._consume_input(dfa, input_str)
            if dfa_state is not None:
                nfa_state.append((terminal, dfa_state))
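The loop above is the heart of the simulation: every terminal's DFA consumes the same input string, and the terminals whose DFAs survive form the NFA state. A hand-rolled sketch of the same idea (toy transition tables; the real class derives its DFAs from the grammar terminals via interegular):

# One toy DFA per terminal, as a ((state, char) -> state) table plus a start state.
DFAS = {
    'NUMBER': ({(0, d): 0 for d in '0123456789'}, 0),
    'NAME': ({(0, c): 0 for c in 'abcdefghijklmnopqrstuvwxyz_'}, 0),
}

def consume_input(transitions, start, s):
    state = start
    for ch in s:
        state = transitions.get((state, ch))
        if state is None:
            return None  # this terminal's DFA dies on the input
    return state

def nfa_state(input_str):
    live = []
    for terminal, (transitions, start) in DFAS.items():
        dfa_state = consume_input(transitions, start, input_str)
        if dfa_state is not None:
            live.append((terminal, dfa_state))
    return live

print(nfa_state('42'))  # [('NUMBER', 0)]: only the NUMBER DFA survives
print(nfa_state('x'))   # [('NAME', 0)]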
@@ -156,11 +160,11 @@ def _convert_lookup_from_list_to_mask(self):
    def _lookup_next_tokens_for_dfa_state(self, cur_terminal, dfa_state, next_terminal) -> torch.Tensor:
        tokens = self._dfa_state_and_next_terminal_to_tokens[(cur_terminal, dfa_state, next_terminal)]
        if tokens == []:
-            return torch.zeros(len(self._vocab), dtype=torch.bool)
+            return self._get_default_mask()
        return tokens

    def _lookup_next_tokens(self, nfa_state, r: ParseResult) -> torch.Tensor:
-        overapprox_token_ids = torch.zeros(len(self._vocab), dtype=torch.bool)
+        overapprox_token_ids = self._get_default_mask()
        # print('Time taken for NFA state:', time.time() - start_time, flush=True)

        if r.remainder_state == RemainderState.COMPLETE:
@@ -204,7 +208,7 @@ def get_overapprox_tokens_mask(self, r: ParseResult, get_list=False):

    def _get_tokens_mask(self, tokens_idx_list) -> torch.Tensor:
        indices = torch.tensor(tokens_idx_list)
-        tokens_mask = torch.zeros(len(self._vocab), dtype=torch.bool)
+        tokens_mask = self._get_default_mask()
        tokens_mask[indices] = 1
        return tokens_mask
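Taken together, the replacements above establish one invariant: every mask starts from _get_default_mask(), and masks are only ever combined by union or by setting indices to True, so a special token, once enabled, survives every later operation. A minimal standalone sketch of that invariant (an EOS id of 2 in a toy vocabulary of 10 is assumed):

import torch

eos_id = 2
default = torch.zeros(10, dtype=torch.bool)
default[eos_id] = True

# Two per-state masks, both built on top of the default mask.
state_a = default.clone(); state_a[torch.tensor([1, 3])] = True
state_b = default.clone(); state_b[torch.tensor([4])] = True

union = state_a | state_b  # the fast union the boolean lookup tables enable
assert union[eos_id]       # EOS can never be masked out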

7 changes: 6 additions & 1 deletion llm_cfg/test_nfa.py
@@ -60,6 +60,11 @@ def test_nfa8():
    assert 'num' in ac_list
    assert '()' in ac_list

+def test_nfa9():
+    r = ParseResult({}, {}, '', RemainderState.MAYBE_COMPLETE)
+    ac_list = nfa.get_overapprox_tokens_mask(r, get_list=True)
+    assert '</s>' in ac_list # special token should always be in the list
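The new test relies on get_list=True translating the boolean mask back into token strings. A sketch of that translation (mask_to_tokens is a hypothetical helper for illustration, not the repository's code):

import torch

def mask_to_tokens(mask, vocab):
    # Keep exactly the vocabulary entries whose mask bit is True.
    return [token for token, keep in zip(vocab, mask.tolist()) if keep]

vocab = ['<s>', '</s>', 'def', 'return']
mask = torch.tensor([False, True, True, False])
print(mask_to_tokens(mask, vocab))  # ['</s>', 'def']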

def test_indentation():
    from mxeval.data import get_data
    mbpp = get_data("mbxp", "python")
@@ -68,5 +73,5 @@ def test_indentation():
    assert p._get_indentation(mbpp['MBPP/2']["prompt"]) == 2
    assert p._get_indentation(mbpp['MBPP/8']["prompt"]) == 1

-tests = [test_nfa, test_nfa2, test_nfa3, test_nfa4, test_nfa5, test_nfa6, test_nfa7, test_nfa8, test_indentation]
+tests = [test_nfa, test_nfa2, test_nfa3, test_nfa4, test_nfa5, test_nfa6, test_nfa7, test_nfa8, test_nfa9, test_indentation]
common.run_tests(tests)
