diff --git a/llm_cfg/common.py b/llm_cfg/common.py
index a3850eda..35f5bbec 100644
--- a/llm_cfg/common.py
+++ b/llm_cfg/common.py
@@ -43,7 +43,7 @@ def load_nfa(tokenizer=None, inc_parser=None, use_cache=True):
         exceptions = {'COMMENT': '#.*|\'\'\'.*?\'\'\'|""".*?"""/is', '_NL': '(\r?\n[\t ]*)+', 'LONG_STRING': '\'\'\'.*?\'\'\'|""".*?"""/is', 'STRING': '[ubf]?r?(".*?"|\'.*?\')'} # , '_TAB': '\t+'
-        nfa = TerminalsNFA(inc_parser.parser.terminals, vocab, exceptions=exceptions)
+        nfa = TerminalsNFA(inc_parser.parser.terminals, vocab, exceptions=exceptions, special_token_ids=[tokenizer.eos_token_id])
         print(f'Time taken for creating NFA:', time.time() - start_time, flush=True)
         pickle.dump(nfa, open(NFA_LOC, 'wb'))
diff --git a/llm_cfg/infer.py b/llm_cfg/infer.py
index a62912a0..b1409fba 100644
--- a/llm_cfg/infer.py
+++ b/llm_cfg/infer.py
@@ -7,7 +7,7 @@
     PreTrainedTokenizer,
     LogitsProcessorList
 )
-from core import filter_code, run_eval, fix_indents
+from core import run_eval
 import os
 import torch
 import argparse
diff --git a/llm_cfg/terminals_nfa.py b/llm_cfg/terminals_nfa.py
index ccc4395c..229c5404 100644
--- a/llm_cfg/terminals_nfa.py
+++ b/llm_cfg/terminals_nfa.py
@@ -21,11 +21,12 @@ class TerminalsNFA:
     """
     We build an NFA that consists of DFAs for each terminal. We simulate the NFA by consuming the input string for each terminal DFA.
     """
-    def __init__(self, terminals: list, vocab, exceptions={}):
+    def __init__(self, terminals: list, vocab, exceptions={}, special_token_ids=[]):
         self._terminals_to_dfa = {}
         self._vocab = vocab
         self.anything_else = interegular.fsm.anything_else # This is special character used for the DFAs
-        self.exceptions = []
+        self.exceptions = []
+        self.special_token_ids = special_token_ids
 
         for terminal in terminals:
             if terminal.name in exceptions:
@@ -42,13 +43,18 @@ def __init__(self, terminals: list, vocab, exceptions={}):
         self._convert_lookup_from_list_to_mask() # convert to boolean tensor mask. This is useful for fast union operations
 
+    def _get_default_mask(self):
+        mask = torch.zeros(len(self._vocab), dtype=torch.bool)
+        for token_id in self.special_token_ids:
+            mask[token_id] = True
+        return mask
 
     def _store_overapproximate_tokens(self, terminals: list[str], vocab):
         for cur_terminal in terminals:
             for dfa_state in self._terminals_to_dfa[cur_terminal].states:
 
                 # Initialize the overapproximate tokens for each dfa state
-                self._dfa_state_to_tokens[(cur_terminal, dfa_state)] = torch.zeros(len(self._vocab), dtype=torch.bool)
+                self._dfa_state_to_tokens[(cur_terminal, dfa_state)] = self._get_default_mask()
                 # self._dfa_state_and_next_terminal_to_tokens[(dfa_state, next_terminal)] = []
 
                 for token_idx, token in enumerate(vocab):
@@ -139,8 +145,6 @@ def _nfa_state(self, input_str):
         """
         nfa_state = []
         for (termianl, dfa) in self._terminals_to_dfa.items():
-            # print(termianl)
-            # print(termianl, input_str)
             dfa_state = self._consume_input(dfa, input_str)
             if dfa_state is not None:
                 nfa_state.append((termianl, dfa_state))
@@ -156,11 +160,11 @@ def _convert_lookup_from_list_to_mask(self):
     def _lookup_next_tokens_for_dfa_state(self, cur_terminal, dfa_state, next_terminal) -> torch.Tensor:
         tokens = self._dfa_state_and_next_terminal_to_tokens[(cur_terminal, dfa_state, next_terminal)]
         if tokens == []:
-            return torch.zeros(len(self._vocab), dtype=torch.bool)
+            return self._get_default_mask()
         return tokens
 
     def _lookup_next_tokens(self, nfa_state, r: ParseResult) -> torch.Tensor:
-        overapprox_token_ids = torch.zeros(len(self._vocab), dtype=torch.bool)
+        overapprox_token_ids = self._get_default_mask()
         # print('Time taken for NFA state:', time.time() - start_time, flush=True)
 
         if r.remainder_state == RemainderState.COMPLETE:
@@ -204,7 +208,7 @@ def get_overapprox_tokens_mask(self, r: ParseResult, get_list=False):
 
     def _get_tokens_mask(self, tokens_idx_list) -> torch.Tensor:
         indices = torch.tensor(tokens_idx_list)
-        tokens_mask = torch.zeros(len(self._vocab), dtype=torch.bool)
+        tokens_mask = self._get_default_mask()
         tokens_mask[indices] = 1
         return tokens_mask
diff --git a/llm_cfg/test_nfa.py b/llm_cfg/test_nfa.py
index 981fb7ec..a7268206 100644
--- a/llm_cfg/test_nfa.py
+++ b/llm_cfg/test_nfa.py
@@ -60,6 +60,11 @@ def test_nfa8():
     assert 'num' in ac_list
     assert '()' in ac_list
 
+def test_nfa9():
+    r = ParseResult({}, {}, '', RemainderState.MAYBE_COMPLETE)
+    ac_list = nfa.get_overapprox_tokens_mask(r, get_list=True)
+    assert '' in ac_list # special token should always be in the list
+
 def test_indetantaion():
     from mxeval.data import get_data
     mbpp = get_data("mbxp", "python")
@@ -68,5 +73,5 @@ def test_indetantaion():
     assert p._get_indentation(mbpp['MBPP/2']["prompt"]) == 2
     assert p._get_indentation(mbpp['MBPP/8']["prompt"]) == 1
 
-tests = [test_nfa, test_nfa2, test_nfa3, test_nfa4, test_nfa5, test_nfa6, test_nfa7, test_nfa8, test_indetantaion]
+tests = [test_nfa, test_nfa2, test_nfa3, test_nfa4, test_nfa5, test_nfa6, test_nfa7, test_nfa8, test_nfa9, test_indetantaion]
 common.run_tests(tests)
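The recurring change in this patch is that every vocabulary mask is now seeded from `_get_default_mask()` instead of an all-False `torch.zeros(...)` tensor, so special tokens (here just `tokenizer.eos_token_id`) are never masked out and generation can always terminate. A minimal, self-contained sketch of that behavior, assuming a toy 10-token vocabulary with id 9 standing in for the EOS id (the names and sizes are illustrative, not the repository's API):

```python
# Standalone sketch of the default-mask idea; toy vocab size and token ids
# are hypothetical stand-ins, not values from llm_cfg.
import torch

def get_default_mask(vocab_size: int, special_token_ids: list[int]) -> torch.Tensor:
    # Start with everything disallowed, then pre-allow the special tokens
    # (e.g. EOS) so no grammar-derived mask can ever block termination.
    mask = torch.zeros(vocab_size, dtype=torch.bool)
    for token_id in special_token_ids:
        mask[token_id] = True
    return mask

mask = get_default_mask(10, special_token_ids=[9])
assert mask[9] and not mask[:9].any()  # only the special token is allowed

# Union with a grammar-derived mask keeps the special token allowed.
grammar_mask = torch.zeros(10, dtype=torch.bool)
grammar_mask[torch.tensor([2, 5])] = True  # tokens the grammar permits
combined = mask | grammar_mask
assert combined[9]  # EOS survives every union
```

Because the default mask is the starting point for every union in `_lookup_next_tokens`, the special token ids survive all subsequent mask combinations, which is what the new `test_nfa9` checks at the list level.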