Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamugare committed Oct 30, 2023
1 parent 6d7f691 commit d3bce83
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 69 deletions.
51 changes: 35 additions & 16 deletions llm_cfg/incremental_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,23 @@
from lark import Lark


class ParseResult:
    """
    Stores the result of parsing a partial program.

    Attributes:
        cur_accept_terminals: terminals acceptable at the current position.
        next_accept_terminals: terminals acceptable after the current
            terminal completes; ``None`` when the current terminal is
            still incomplete (the next token must extend it instead).
        final_incomplete_str: the trailing (possibly incomplete) terminal
            string at the end of the parsed code.
        is_terminal_complete: whether ``final_incomplete_str`` forms a
            complete terminal.

    Raises:
        ValueError: if ``is_terminal_complete`` is False but
            ``next_accept_terminals`` is not ``None`` — an inconsistent
            combination. (An explicit raise is used instead of ``assert``
            so the invariant survives ``python -O``.)
    """
    def __init__(self, cur_accept_terminals, next_accept_terminals, final_incomplete_str, is_terminal_complete):
        # If the terminal is incomplete, there is no meaningful "next"
        # terminal set yet; enforce that callers pass None in that case.
        if not is_terminal_complete and next_accept_terminals is not None:
            raise ValueError(
                'next_accept_terminals must be None when the final terminal is incomplete'
            )

        self.final_incomplete_str = final_incomplete_str
        self.is_terminal_complete = is_terminal_complete  # Whether final_incomplete_str is a complete terminal
        self.cur_accept_terminals = cur_accept_terminals
        self.next_accept_terminals = next_accept_terminals

class IncrementalParser:
"""
This class implements an incremental parser for Python code.
"""
def __init__(self):
self.parser = Lark.open( # This is the standard Lark parser
"llm_cfg/python_grammar.lark",
Expand All @@ -32,7 +48,7 @@ def __init__(self):
self.prev_lexer_tokens = None
self.cur_pos_to_interactive = {}

def get_acceptable_next_terminals(self, code):
def get_acceptable_next_terminals(self, code) -> ParseResult:
# Stores the sequence of tokens that the parser has seen in the order
last_terminal_complete = True
interactive = self.interactive
Expand Down Expand Up @@ -105,6 +121,9 @@ def get_acceptable_next_terminals(self, code):
last_terminal_complete = False
current_term_str = code[self.lexer_pos:]
current_term_str = current_term_str.lstrip(' ') # Remove space from the beginning

if current_term_str == '':
last_terminal_complete = True
# print('current_term_str 1:', current_term_str)
else:
# Although this is a complete terminal, it may happen that this may be just prefix of some other terminal
Expand All @@ -114,44 +133,44 @@ def get_acceptable_next_terminals(self, code):

if last_terminal_complete:
if self.parser_token_seq[-1].type == '_NL':
next_ac_terminals = self.next_ac_terminals
# Compute next line accepted indentation levels
max_next_indentation_level = 0
# print('next_ac_terminals:', next_ac_terminals)

if '_INDENT' in next_ac_terminals:
if '_INDENT' in self.next_ac_terminals:
max_next_indentation_level = self.cur_indentation_level + 1
elif '_DEDENT' in next_ac_terminals and len(next_ac_terminals)==1:
elif '_DEDENT' in self.next_ac_terminals and len(self.next_ac_terminals)==1:
max_next_indentation_level = self.cur_indentation_level - 1
elif '_DEDENT' in next_ac_terminals and len(next_ac_terminals)>1:
elif '_DEDENT' in self.next_ac_terminals and len(self.next_ac_terminals)>1:
max_next_indentation_level = self.cur_indentation_level

cur_tabs = self.parser_token_seq[-1].value.split('\n')[-1].count('\t')

# Remove the _INDENT and _DEDENT tokens from the acceptable tokens
# since we inform the indentation level through the _TAB token
if '_INDENT' in next_ac_terminals:
next_ac_terminals.remove('_INDENT')
if '_DEDENT' in next_ac_terminals:
next_ac_terminals.remove('_DEDENT')
if '_INDENT' in self.next_ac_terminals:
self.next_ac_terminals.remove('_INDENT')
if '_DEDENT' in self.next_ac_terminals:
self.next_ac_terminals.remove('_DEDENT')

# '_NL' is always accepted in this case
next_ac_terminals.add('_NL')
self.next_ac_terminals.add('_NL')

if cur_tabs < max_next_indentation_level:
# print('Expect a tab!')
next_ac_terminals.add('_TAB')
self.next_ac_terminals.add('_TAB')
# elif cur_tabs > max_next_indentation_level:
# raise Exception('Invalid indentation level! max_next_indentation_level: {}, cur_tabs: {}'.format(max_next_indentation_level, cur_tabs))
else:
# Since current terminal is incomplete, next token should add to current terminal
next_ac_terminals = None

else: # Since current terminal is incomplete, next token should add to current terminal
self.next_ac_terminals = None

if self.next_ac_terminals is not None and '_NL' in self.next_ac_terminals:
self.next_ac_terminals.add('COMMENT')

return self.cur_ac_terminals, self.next_ac_terminals, current_term_str

return ParseResult(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, last_terminal_complete)


def _store_parser_state(self, pos, parser_state, indentation_level, accepts):
# print('storing state at position:', pos, len(self.interactive.parser_state.state_stack), len(self.dedent_queue))
dedent_queue = copy.deepcopy(self.dedent_queue)
Expand Down
20 changes: 10 additions & 10 deletions llm_cfg/python_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import common
from transformers import LogitsProcessor, PreTrainedTokenizer
from incremental_parser import IncrementalParser
from incremental_parser import IncrementalParser, ParseResult


class PythonDecoder(LogitsProcessor):
Expand Down Expand Up @@ -34,9 +34,9 @@ def __init__(self, tokenizer: PreTrainedTokenizer, **kwargs):

print(f"Time taken for preprocessing: {time.time() - time_start:.2f}s")

def _print_current_status(self, partial_code, cur_ac_terminals, next_ac_terminals, incomplete_terminal):
def _print_current_status(self, partial_code, r: ParseResult):
print('partial code:\n', repr(partial_code))
print('inc:', repr(incomplete_terminal), '\n', 'cur:', cur_ac_terminals, '\n', 'next:', next_ac_terminals)
print('inc:', repr(r.final_incomplete_str), '\n', 'cur:', r.cur_accept_terminals, '\n', 'next:', r.next_accept_terminals)


def _reset(self):
Expand All @@ -63,27 +63,27 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
self.partial_codes_trace.append(partial_code)

# returns the names of the Terminals that are currently accepted.
cur_ac_terminals, next_ac_terminals, incomplete_terminal = self.inc_parsers[i].get_acceptable_next_terminals(partial_code)
r = self.inc_parsers[i].get_acceptable_next_terminals(partial_code)

greedy_token = self.tokenizer.decode(scores.argmax(dim=-1), skip_special_tokens=True) # For debugging - remove later

if 'EOF' in next_ac_terminals:
if 'EOF' in r.next_accept_terminals:
self.last_valid_state[i] = len(input_ids[i])

self.accept_tokens_sizes.append(len(cur_ac_terminals)) # For profiling
self.accept_tokens_sizes.append(len(r.cur_accept_terminals)) # For profiling

print(i, 'Time taken for compilation:', time.time() - compilation_start_time)
accept_mask = self.terminals_nfa.get_overapprox_tokens_mask(incomplete_terminal, next_ac_terminals)
accept_mask = self.terminals_nfa.get_overapprox_tokens_mask(r.final_incomplete_str, r.next_accept_terminals)

print(i, 'Time taken for overapproximation:', time.time() - compilation_start_time)
if self.debug and self.token_cnt%50==0:
self._print_current_status(partial_code, cur_ac_terminals, next_ac_terminals, incomplete_terminal)
self._print_current_status(partial_code, r)

if torch.sum(accept_mask) != 0: # If there are acceptable tokens for the current partial code
scores[i] = scores[i].masked_fill(~accept_mask.to(scores.device), -float("inf"))
else: # Otherwise, report the error and mask no tokens
print('No acceptable tokens for the current partial code!')
self._print_current_status(partial_code, cur_ac_terminals, next_ac_terminals, incomplete_terminal)
self._print_current_status(partial_code, r)

print(i, 'Time taken for masking:', time.time() - compilation_start_time)

Expand All @@ -92,7 +92,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
if greedy_token != greedy_grammar_token:
print('Greedy token:', repr(greedy_token), scores.argmax(dim=-1))
print('Greedy grammar-based token:', repr(greedy_grammar_token), scores.argmax(dim=-1))
self._print_current_status(partial_code, cur_ac_terminals, next_ac_terminals, incomplete_terminal)
self._print_current_status(partial_code, r)
self.non_matching_token_cnt += 1
except Exception as e:
print("-"*80, '\n', 'Code lenght:', len(partial_code), '\n', partial_code, '\n', repr(partial_code), '\n', 'Error:', e)
Expand Down
Loading

0 comments on commit d3bce83

Please sign in to comment.