Skip to content

Commit

Permalink
Refactoring and some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamugare committed Nov 1, 2023
1 parent 5cd6364 commit 3c0ddbc
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 88 deletions.
1 change: 1 addition & 0 deletions llm_cfg/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def load_nfa(tokenizer=None, inc_parser=None, use_cache=True):
print('Time taken for loading parser:', time.time() - start_time, flush=True)

exceptions = {'COMMENT': '#.*|\'\'\'.*?\'\'\'|""".*?"""/is', '_NL': '(\r?\n[\t ]*)+', 'LONG_STRING': '\'\'\'.*?\'\'\'|""".*?"""/is', 'STRING': '[ubf]?r?(".*?"|\'.*?\')'}
# , '_TAB': '\t+'
nfa = TerminalsNFA(inc_parser.parser.terminals, vocab, exceptions=exceptions)
print(f'Time taken for creating NFA:', time.time() - start_time, flush=True)

Expand Down
100 changes: 45 additions & 55 deletions llm_cfg/incremental_parser.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,30 @@
import copy
import time
import re
import regex
import lark
import lmql_regex
from parse_result import RemainderState, ParseResult
from lark.indenter import Indenter
from lark.lexer import Token
from lark import Lark


class ParseResult:
    """
    Container for the values produced by one parsing step.

    Keeps the two sets of accepted terminals, the trailing string that was
    handed in, and the flag telling whether that trailing string is a complete
    terminal.
    """
    def __init__(self, cur_accept_terminals, next_accept_terminals, final_incomplete_str, is_terminal_complete):
        self.cur_accept_terminals = cur_accept_terminals
        self.next_accept_terminals = next_accept_terminals
        self.final_incomplete_str = final_incomplete_str
        self.is_terminal_complete = is_terminal_complete  # Whether final_incomplete_str is a complete terminal

        # An incomplete terminal must not come with any next_accept_terminals
        assert is_terminal_complete or next_accept_terminals is None

class IncrementalParser:
"""
This class implements an incremental parser for Python code.
"""
def __init__(self):
def __init__(self, partial_code=None):
indenter = PythonIndenter()

if partial_code is not None: # extract indentation type from partial code
indenter.tab_len = self._get_indentation(partial_code) # NOTE: tab_len is useful when \t and spaces are used for indentation in same code

self.parser = Lark.open( # This is the standard Lark parser
"llm_cfg/python_grammar.lark",
parser="lalr",
lexer="basic",
start="file_input",
postlex=PythonIndenter(),
postlex=indenter,
propagate_positions=True,
)
self.cur_ac_terminals = None
Expand All @@ -48,32 +41,37 @@ def __init__(self):
self.prev_lexer_tokens = None
self.cur_pos_to_interactive = {}

def _get_indentation(self, partial_code) -> int:
m = regex.match(r"(.*?):(.*?)\n(.*?)(?![ \t])", partial_code, flags=regex.DOTALL)
indent_type = m.group(3)
tab_len = 4 # Default tab length
if '\t' not in indent_type: # that means we are using spaces for indentation
tab_len = indent_type.count(' ')
return tab_len

def get_acceptable_next_terminals(self, code) -> ParseResult:
# Stores the sequence of tokens that the parser has seen in the order
last_terminal_complete = True
interactive = self.interactive
lexer_tokens = self._lex_code(code)

# Restore the previous state of the parser
if self.prev_lexer_tokens is not None:
i = 0

while i < min(len(self.prev_lexer_tokens), len(lexer_tokens)) and lexer_tokens[i] == self.prev_lexer_tokens[i]:
i += 1

self.cur_pos = i
# print('********Restoring parser state 1!', self.cur_pos-1)
# print(self.prev_lexer_tokens[self.cur_pos-1], lexer_tokens[self.cur_pos-1])
# print(self.cur_pos_to_interactive.keys())
# print(len(self.prev_lexer_tokens), len(lexer_tokens))
# print(self.prev_lexer_tokens)
# print(lexer_tokens)

if (self.cur_pos-1) in self.cur_pos_to_interactive:
# print('*******Restoring parser state 2!', self.cur_pos-1)
# print(self.cur_pos_to_interactive[self.cur_pos-1][0].state_stack, len(self.cur_pos_to_interactive[self.cur_pos-1][0].state_stack), len(self.dedent_queue))
self._restore_parser_state(self.cur_pos-1)

# Find the maximum index such that the tokens are same and the parser state is stored
max_matching_index = -1
for i in range(min(len(self.prev_lexer_tokens), len(lexer_tokens))):
if self.prev_lexer_tokens[i] != lexer_tokens[i]:
break
if i in self.cur_pos_to_interactive:
max_matching_index = i

if max_matching_index != -1:
self.cur_pos = max_matching_index + 1
# print('********Restoring parser state 1!', max_matching_index )
# print(self.prev_lexer_tokens[self.cur_pos-1], lexer_tokens[self.cur_pos-1])
assert (max_matching_index) in self.cur_pos_to_interactive
self._restore_parser_state(max_matching_index)

# Set the previous lexer tokens
self.prev_lexer_tokens = lexer_tokens

# Parse the tokens
Expand Down Expand Up @@ -104,32 +102,30 @@ def get_acceptable_next_terminals(self, code) -> ParseResult:
self._store_parser_state(self.cur_pos-1, interactive.parser_state.copy(), self.cur_indentation_level, interactive.accepts())

except lark.exceptions.UnexpectedToken as e:
# print(e)
pass

# Print the store
# print('JUST PRINTING THE STORED STATES!')
# for pos in self.cur_pos_to_interactive.keys():
# print(pos, len(self.cur_pos_to_interactive[pos][0].state_stack), len(self.cur_pos_to_interactive[pos][3]))

if self.log_time:
print('Time taken for parsing:', (time.time() - parsing_start_time))

reminder_state = None
# Compute current terminal string
if self.lexer_pos < len(code):
last_terminal_complete = False
reminder_state = RemainderState.INCOMPLETE
current_term_str = code[self.lexer_pos:]
# print('current_term_str 1:', current_term_str)
# print('current_term_str 1:', repr(current_term_str))

current_term_str = current_term_str.lstrip(' ') # Remove space from the beginning
if current_term_str == '':
last_terminal_complete = True
reminder_state = RemainderState.COMPLETE
else:
# Although this looks like a complete terminal, it may be just a prefix of some other terminal
# e.g., 'de' looks like a complete variable name, but it may be just a prefix of 'def'
current_term_str = self.parser_token_seq[-1].value
reminder_state = RemainderState.MAYBE_COMPLETE
# print('current_term_str 2:', current_term_str, self.parser_token_seq)

if last_terminal_complete:
if reminder_state == RemainderState.MAYBE_COMPLETE or reminder_state == RemainderState.COMPLETE:
if self.parser_token_seq[-1].type == '_NL':
# Compute next line accepted indentation levels
max_next_indentation_level = 0
Expand Down Expand Up @@ -166,7 +162,7 @@ def get_acceptable_next_terminals(self, code) -> ParseResult:
if self.next_ac_terminals is not None and '_NL' in self.next_ac_terminals:
self.next_ac_terminals.add('COMMENT')

return ParseResult(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, last_terminal_complete)
return ParseResult(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, reminder_state)


def _store_parser_state(self, pos, parser_state, indentation_level, accepts):
Expand Down Expand Up @@ -200,26 +196,23 @@ def get_matching_terminal(self, s):
# TODO: Use priorities to resolve conflicts
return None

def _lex_code(self, code):
def _lex_code(self, code: str) -> list:
# Collect Lexer tokens
lexer_tokens = []
interactive = self.parser.parse_interactive(code)
# interactive = self.interactive
lexing_start_time = time.time()
lexer_state = interactive.lexer_thread.state
indenter = self.parser.lexer_conf.postlex
indenter: Indenter = self.parser.lexer_conf.postlex

# Reset the indentation level
indenter.indent_level = [0]
indenter.paren_level = 0
indenter.indent_level, indenter.paren_level = [0], 0
# print('Starting indent level:', indenter.indent_level)

try:
while lexer_state.line_ctr.char_pos < len(lexer_state.text):
blexer = interactive.lexer_thread.lexer.lexer
token = blexer.next_token(lexer_state)
self.lexer_pos = lexer_state.line_ctr.char_pos

# Perform postlexing indentation
if token.type == indenter.NL_type:
# print('NL token:', indenter.indent_level)
Expand All @@ -240,10 +233,7 @@ def _lex_code(self, code):
# print(lexer_state.line_ctr.char_pos, len(lexer_state.text))
pass
# raise e
# Add the remaining dedent tokens at the end
# while len(indenter.indent_level) > 1:
# indenter.indent_level.pop()
# lexer_tokens.append(Token(indenter.DEDENT_type, ''))

if self.log_time:
print('Time taken for lexing:', time.time() - lexing_start_time)
# print(lexer_tokens)
Expand Down
25 changes: 25 additions & 0 deletions llm_cfg/parse_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from enum import Enum

class RemainderState(Enum):
    """
    The state of the remainder (the trailing text) after parsing partial code.
    """
    # As set by the incremental parser: nothing but spaces was left unlexed,
    # so the remainder is empty
    COMPLETE = 0
    # The remainder is the last lexed token, which may still be only a prefix
    # of a longer terminal (e.g. 'de' could grow into 'def')
    MAYBE_COMPLETE = 1
    # The remainder is trailing text that has not been lexed into a token yet
    INCOMPLETE = 2


class ParseResult:
    """
    Stores the result of parsing.

    Attributes:
        remainder: the trailing string passed in by the parser.
        remainder_state: the RemainderState describing `remainder`.
        cur_accept_terminals: stored as given.
        next_accept_terminals: stored as given; must be None when
            `remainder_state` is RemainderState.INCOMPLETE.
    """
    def __init__(self, cur_accept_terminals, next_accept_terminals, remainder, remainder_state: RemainderState):
        self.remainder = remainder
        self.remainder_state = remainder_state  # Whether the remainder is a complete terminal
        self.cur_accept_terminals = cur_accept_terminals
        self.next_accept_terminals = next_accept_terminals

        if remainder_state == RemainderState.INCOMPLETE:  # If the terminal is not complete, then next_accept_terminals should be None
            assert next_accept_terminals is None

    def __repr__(self):
        # Labels match the attribute names so the output can be read against the class
        return (
            f'remainder: {self.remainder!r}\n'
            f'remainder_state: {self.remainder_state}\n'
            f'cur_accept_terminals: {self.cur_accept_terminals}\n'
            f'next_accept_terminals: {self.next_accept_terminals}'
        )
5 changes: 3 additions & 2 deletions llm_cfg/python_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, tokenizer: PreTrainedTokenizer, **kwargs):

def _print_current_status(self, partial_code, r: ParseResult):
print('partial code:\n', repr(partial_code))
print('inc:', repr(r.final_incomplete_str), '\n', 'cur:', r.cur_accept_terminals, '\n', 'next:', r.next_accept_terminals)
print('inc:', repr(r.remainder), '\n', 'cur:', r.cur_accept_terminals, '\n', 'next:', r.next_accept_terminals)


def _reset(self):
Expand Down Expand Up @@ -76,7 +76,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
accept_mask = self.terminals_nfa.get_overapprox_tokens_mask(r)

print(i, 'Time taken for overapproximation:', time.time() - compilation_start_time)
if self.debug and self.token_cnt%50==0:
if self.debug:
# print(scores[i][:20])
self._print_current_status(partial_code, r)

if torch.sum(accept_mask) != 0: # If there are acceptable tokens for the current partial code
Expand Down
21 changes: 14 additions & 7 deletions llm_cfg/terminals_nfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import interegular
import torch
import regex
from incremental_parser import ParseResult
from parse_result import RemainderState, ParseResult

class Exception:
"""
Expand Down Expand Up @@ -159,15 +159,22 @@ def _lookup_next_tokens_for_dfa_state(self, cur_terminal, dfa_state, next_termin
return torch.zeros(len(self._vocab), dtype=torch.bool)
return tokens

def _lookup_next_tokens(self, nfa_state, next_terminals: list) -> torch.Tensor:
def _lookup_next_tokens(self, nfa_state, r: ParseResult) -> torch.Tensor:
overapprox_token_ids = torch.zeros(len(self._vocab), dtype=torch.bool)
# print('Time taken for NFA state:', time.time() - start_time, flush=True)


if r.remainder_state == RemainderState.COMPLETE:
for (terminal, dfa_state) in nfa_state:
if terminal in r.next_accept_terminals:
overapprox_token_ids |= self._dfa_state_to_tokens[(terminal, dfa_state)]
return overapprox_token_ids

# Case when the final string may be incomplete
for (cur_terminal, dfa_state) in nfa_state:
if next_terminals == None: # This is the case when we have incomplete final string
if r.next_accept_terminals == None: # This is the case when we have incomplete final string
overapprox_token_ids |= self._dfa_state_to_tokens[(cur_terminal, dfa_state)]
else:
for next_terminal in next_terminals:
for next_terminal in r.next_accept_terminals:
overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(cur_terminal, dfa_state, next_terminal)
return overapprox_token_ids

Expand All @@ -179,15 +186,15 @@ def _exception_rule(self, s, exceptions: list[Exception]) -> str:

def get_overapprox_tokens_mask(self, r: ParseResult, get_list=False):
# start_time = time.time()
cur_incomplete_string = self._exception_rule(r.final_incomplete_str, self.exceptions)
cur_incomplete_string = self._exception_rule(r.remainder, self.exceptions)
# print(cur_incomplete_string)
if cur_incomplete_string is None:
return torch.ones(len(self._vocab), dtype=torch.bool)

cur_nfa_state = self._nfa_state(cur_incomplete_string)
print(cur_nfa_state)

overapprox_token_ids = self._lookup_next_tokens(cur_nfa_state, r.next_accept_terminals)
overapprox_token_ids = self._lookup_next_tokens(cur_nfa_state, r)

# print('Time taken for union:', time.time() - start_time, flush=True)
if get_list: # This is useful for testing
Expand Down
Loading

0 comments on commit 3c0ddbc

Please sign in to comment.