Fixes to parser and explicitly print parsing error
shubhamugare committed Dec 25, 2024
1 parent a034a36 commit a88b744
Showing 8 changed files with 178 additions and 42 deletions.
131 changes: 131 additions & 0 deletions notebooks/tests/builtin_grammar.ipynb
@@ -0,0 +1,131 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"from syncode import SyncodeLogitsProcessor\n",
"from syncode import Grammar\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"import os\n",
"\n",
"HF_CACHE = os.environ['HF_CACHE'] if 'HF_CACHE' in os.environ else 'cache/'\n",
"HF_ACCESS_TOKEN = os.environ['HF_ACCESS_TOKEN'] if 'HF_ACCESS_TOKEN' in os.environ else None\n",
"\n",
"device = 'cuda'\n",
"model_name = \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"# model_name = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Write a java function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
" \n",
"\n",
"[OUTPUT] public class HelloWorld {\n",
" public static void main(String[] args) {\n",
" System.out.println(\"Hello World\");\n",
" } \n",
"\n",
" public static void printReverse(String str) {\n",
" char[] arr = str.toCharArray();\n",
" int start = 0;\n",
" int end = arr.length - 1;\n",
"\n",
" while (start < end) {\n",
" System.out.print(arr[start]);\n",
" System.out.print(arr[end]);\n",
" start++;\n",
" end--;\n",
" } \n",
" System.out.println();\n",
" } \n",
"}\n"
]
}
],
"source": [
"# grammar_str = \"python\"\n",
"# grammar_str = \"go\"\n",
"grammar_str = \"java\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
"syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
"\n",
"prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
"prompt = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
" )\n",
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
"\n",
"syncode_logits_processor.reset(prompt)\n",
"\n",
"inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n",
"\n",
"attention_mask = torch.ones_like(inputs)\n",
"output = model.generate(\n",
" inputs,\n",
" attention_mask=attention_mask,\n",
" max_length=512, \n",
" num_return_sequences=1, \n",
" pad_token_id=tokenizer.eos_token_id, \n",
" logits_processor=[syncode_logits_processor]\n",
" )\n",
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "codex",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
30 changes: 18 additions & 12 deletions syncode/evaluation/code_eval.py
@@ -27,6 +27,8 @@ def run_code_eval(
samples = []
outputs = []

+ assert syncode.mode == 'original' or syncode.parse_output_only == False, "The SynCode flag parse_output_only should be False for code evaluation with grammar mode"

if syncode.language == "python":
stop_words = ["\n\n\n"]
elif syncode.language == "go":
@@ -49,7 +51,7 @@
logger.log(f"Functional result: {functional_result}")

# Also log these results in a separate file
- CodeEval.write_results(syncode, out_path, avg_time, functional_result)
+ CodeEval.write_results(syncode, out_path, avg_time, functional_result, num_tasks)
else: # Debugging a specific task
debug_task_id = list(problems.keys())[debug_task_id]
return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger, stop_words=stop_words)
@@ -100,14 +102,15 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp
torch.cuda.empty_cache()
return all_completions

- def write_results(self, out_path, avg_time, functional_result):
+ def write_results(syncode, out_path, avg_time, functional_result, num_tasks=1):
"""
Write results to a separate file
"""
file_path = "results/syncode_results.txt"
os.makedirs("results", exist_ok=True)
with open(file_path, "a") as f:
f.write(f"{self.model_name} | {self.grammar} | {self.dataset} | {self.parser} | {self.num_samples} | {self.mode}\n")
f.write(f"{syncode.model_name} | {syncode.grammar} | {syncode.dataset} | {syncode.parser} | {syncode.num_samples} | {syncode.mode} | num tasks: {num_tasks}\n")
f.write(f"Generation args: {syncode.model.gen_args}\n")
f.write(f"Functional result: {functional_result}\n")
f.write(f"Output path: {out_path}\n")
f.write(f"Averge time taken for each task: {avg_time:.2f}s\n")
@@ -136,15 +139,18 @@ def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_
return completion

def compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion):
- fn_ends = sorted(list(set(grammar_decoder.function_ends[i])))
- if grammar_decoder.function_ends[i] is not None and len(fn_ends) > 1:
-     # if the function end is not None, then the last valid state is the function end
-     last_valid_state = fn_ends[1]
- else:
-     # otherwise, the last valid state is the last valid state
-     function_incomplete[i] = True
-     last_valid_state = grammar_decoder.last_valid_state[i]
+ if grammar_decoder.function_ends[i] is not None:
+     fn_ends = sorted(list(set(grammar_decoder.function_ends[i])))
+     if len(fn_ends) > 1:
+         # a function end was recorded: backtrack the completion to it
+         last_valid_state = fn_ends[1]
+         return raw_completion[:last_valid_state]
+
+ # otherwise, fall back to the last valid parser state
+ function_incomplete[i] = True
+ last_valid_state = grammar_decoder.last_valid_state[i]
+
  # Use when the stop word does not exist in the completion
  backup_completion = raw_completion[:last_valid_state]
- return backup_completion
+ return backup_completion

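The refactored helper now returns early when at least two function-end positions were recorded, and otherwise backs off to the parser's last valid state. A minimal self-contained sketch of that logic; `function_ends` and `last_valid_state` are plain arguments standing in for the per-sample fields on the grammar decoder:

from typing import List, Optional, Tuple

def backup_completion(raw_completion: str,
                      function_ends: Optional[List[int]],
                      last_valid_state: int) -> Tuple[str, bool]:
    # Truncate a raw completion to a syntactically valid prefix; the
    # boolean result reports whether the function is incomplete.
    if function_ends is not None:
        fn_ends = sorted(set(function_ends))
        if len(fn_ends) > 1:
            # A function end was recorded: cut the completion there
            return raw_completion[:fn_ends[1]], False
    # No usable function end: fall back to the last position at which
    # the incremental parser was still in a valid state
    return raw_completion[:last_valid_state], True

# Example: with recorded ends at 10 and 42, the completion is cut at 42
code, incomplete = backup_completion("x" * 100, [10, 42], 55)
assert len(code) == 42 and not incomplete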
10 changes: 8 additions & 2 deletions syncode/grammar_decoder.py
@@ -8,7 +8,7 @@
from syncode.parsers.grammars import Grammar

# Set to True for debugging
- DEBUG = False
+ DEBUG = True

class SyncodeLogitsProcessor(LogitsProcessor):
"""
@@ -38,6 +38,7 @@ def __init__(self,
self.logger = logger
self.dev_mode = dev_mode
self.batch_size = num_samples
+ self.parse_failed = False

# For backtracking to syntactically valid completions
self.last_valid_state: list = []
@@ -90,6 +91,7 @@ def reset(self, prompt: str):
"""
self.last_valid_state = [0 for _ in range(self.batch_size)]
self.function_ends = [None for _ in range(self.batch_size)]
+ self.parse_failed = False

prompt_tokens = self.tokenizer.encode(prompt, return_tensors='pt')[0]
if self.parse_output_only:
@@ -149,7 +151,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
except Exception as e:
if self.dev_mode == True:
raise e
self.logger.log(f"Exception while parsing:\n {e}")
elif self.parse_failed == False:
self.parse_failed = True
print("-"*50)
print(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
print("-"*50)
continue # Skip altering the scores for this batch

accept_mask = self.dfa_mask_store.get_accept_mask(r, logger=self.logger)
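Combined with the `parse_failed` flag initialized in `__init__` and cleared in `reset`, the `__call__` path now degrades gracefully: the first parsing exception prints a diagnostic, and every later step decodes unconstrained. A condensed sketch of the pattern; the parser call inside `try` and the way the mask is applied to the scores are assumptions based on the surrounding hunks, not verbatim source:

import torch

def constrained_scores(processor, partial_code: str, scores: torch.FloatTensor) -> torch.FloatTensor:
    # One-shot fallback: constrain scores while parsing succeeds, print
    # the error once and decode unconstrained after a parse failure.
    try:
        r = processor.inc_parser.get_acceptable_next_terminals(partial_code)
    except Exception as e:
        if processor.dev_mode:
            raise
        if not processor.parse_failed:
            processor.parse_failed = True  # print the diagnostic only once
            print("-" * 50)
            print(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}")
            print("-" * 50)
        return scores  # unchanged scores, i.e. unconstrained decoding
    accept_mask = processor.dfa_mask_store.get_accept_mask(r)
    return scores.masked_fill(~accept_mask, float("-inf"))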
2 changes: 1 addition & 1 deletion syncode/parsers/go_parser.py
@@ -47,7 +47,7 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
self._accepts(interactive)
)
except lark.exceptions.UnexpectedToken as e:
- self._handle_parsing_error(lexer_tokens, token)
+ self._handle_parsing_error(lexer_tokens, token, e)
parse_incomplete = True

# Compute current terminal string
8 changes: 5 additions & 3 deletions syncode/parsers/incremental_parser.py
@@ -154,7 +154,7 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:

except lark.exceptions.UnexpectedToken as e:
parse_incomplete = True
- self._handle_parsing_error(lexer_tokens, token)
+ self._handle_parsing_error(lexer_tokens, token, e)

# Compute current terminal string
remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)
@@ -175,10 +175,12 @@ def _get_remainder(self, code, lexing_incomplete=False, parse_incomplete=False):
remainder_state = RemainderState.INCOMPLETE
self.cur_ac_terminals = self.next_ac_terminals
self.next_ac_terminals = set()

elif parse_incomplete: # Parsing is incomplete
remainder_state = RemainderState.INCOMPLETE
current_term_str = self.parsed_lexer_tokens[-1].value
final_terminal = self.parsed_lexer_tokens[-1].type

elif len(self.parsed_lexer_tokens) > 0:
if self.lexer_pos < len(code): # In this case the final lexical tokens are ignored by the parser
remainder_state = RemainderState.COMPLETE
@@ -199,14 +201,14 @@ def _accepts(self, interactive_parser: InteractiveParser) -> set:
accepts = interactive_parser.accepts()
return accepts

- def _handle_parsing_error(self, lexer_tokens, token):
+ def _handle_parsing_error(self, lexer_tokens, token, error):
"""
Handles the error that occurs when the lexer token is not parsed correctly.
1. If the final token is not parsed correctly, then it is okay.
2. If a non-final token is not parsed correctly, then it is an issue. We log the warning in that case.
"""
if token != lexer_tokens[-1]:
- self.logger.log_error(f'Error in parsing the token: {token} which is not the last token in the lexer_tokens: {lexer_tokens}')
+ raise error
else:
# If it is the final token that gave the error, then it is okay
self.cur_ac_terminals = self.next_ac_terminals
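Threading the original `lark` exception into `_handle_parsing_error` tightens its contract: an `UnexpectedToken` on the final lexer token is expected for partial programs and is absorbed, while an error on any earlier token is now re-raised (and, via the logits-processor change above, surfaced to the user) instead of only being logged. A standalone restatement of that contract, as a sketch rather than the class method:

import lark

def check_parse_error(lexer_tokens, token, error: lark.exceptions.UnexpectedToken):
    # Partial code routinely ends mid-token, so a failure on the last
    # lexer token is tolerable; anywhere else it signals a real problem
    # in the grammar or parser state and should propagate.
    if token != lexer_tokens[-1]:
        raise error
    # Final-token failure: the caller marks the parse incomplete and
    # treats the remainder accordingly.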
33 changes: 11 additions & 22 deletions syncode/parsers/python_parser.py
@@ -32,10 +32,10 @@ def _get_indentation(self, partial_code) -> int:
tab_len = indent_type.count(' ')
return tab_len

- def get_acceptable_next_terminals(self, code) -> ParseResult:
+ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
# Stores the sequence of tokens that the parser has seen in the order
interactive = self.interactive
- lexer_tokens = self._lex_code(code)
+ lexer_tokens, lexing_incomplete = self._lex_code(partial_code)

# Restore the previous state of the parser
self._restore_recent_parser_state(lexer_tokens)
@@ -44,6 +44,7 @@ def get_acceptable_next_terminals(self, code) -> ParseResult:

# Parse the tokens
self.time_accepts = 0
+ parse_incomplete = False

try:
while self.cur_pos < len(lexer_tokens):
@@ -75,27 +76,17 @@ def get_acceptable_next_terminals(self, code) -> ParseResult:
indent_levels=copy.copy(self.indent_level)
)
except lark.exceptions.UnexpectedToken as e:
- self._handle_parsing_error(lexer_tokens, token)
+ parse_incomplete = True
+ self._handle_parsing_error(lexer_tokens, token, e)

- remainder_state, final_terminal = None, None

# Compute current terminal string
- if self.lexer_pos < len(code):
-     remainder_state = RemainderState.INCOMPLETE
-     current_term_str = code[self.lexer_pos:]
-
-     current_term_str = current_term_str.lstrip(' ') # Remove space from the beginning
-     if current_term_str == '':
-         remainder_state = RemainderState.COMPLETE
- else:
-     # Although this is a complete terminal, it may happen that this may be just prefix of some other terminal
-     # e.g., 'de' may seem like a variable name that is complete, but it may be just a prefix of 'def'
-     current_term_str = self.parsed_lexer_tokens[-1].value
-     remainder_state = RemainderState.MAYBE_COMPLETE
-     final_terminal = self.parsed_lexer_tokens[-1].type
+ remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)

next_ac_indents = None
if remainder_state == RemainderState.MAYBE_COMPLETE or remainder_state == RemainderState.COMPLETE:
- if self.parsed_lexer_tokens[-1].type == '_NL':
+ if len(self.parsed_lexer_tokens) > 0 and self.parsed_lexer_tokens[-1].type == '_NL':
last_indent_str = self.parsed_lexer_tokens[-1].value.split('\n')[-1]
last_indent = last_indent_str.count(' ') + last_indent_str.count('\t') * self.tab_len
next_ac_indents = [indent-last_indent for indent in self.indent_level if indent >= last_indent]
@@ -111,10 +102,6 @@ def get_acceptable_next_terminals(self, code) -> ParseResult:
self.cur_ac_terminals.add('_NL')
self.next_ac_terminals.add('_NL')

- else: # Since current terminal is incomplete, next token should add to current terminal
-     self.cur_ac_terminals = self.next_ac_terminals
-     self.next_ac_terminals = set()

return ParseResult.from_accept_terminals(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, remainder_state, next_ac_indents=next_ac_indents, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore)

def _update_indent_levels(self, indent_level, indent):
@@ -131,6 +118,7 @@ def _lex_code(self, code: str) -> Iterable[Token]:
interactive = self.base_parser.parse_interactive(code)
lexer_state = interactive.lexer_thread.state
indenter: PythonIndenter = self.base_parser.lexer_conf.postlex
+ lexing_incomplete = False

# Reset the indentation level
indenter.indent_level, indenter.paren_level = [0], 0
@@ -154,11 +142,12 @@
indenter.paren_level -= 1
assert indenter.paren_level >= 0
except lark.exceptions.UnexpectedCharacters as e:
+ lexing_incomplete = True
pass # This may happen when the partial code has an ignore terminal
except EOFError as e:
pass

- return lexer_tokens
+ return lexer_tokens, lexing_incomplete


class PythonIndenter(Indenter):
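With `_lex_code` now also reporting whether lexing ended mid-token, the Python parser reuses the shared `_get_remainder` instead of carrying its own branching. The classification reduces to a three-way decision, summarized in this sketch; the boolean arguments condense the attribute checks in the incremental_parser.py hunk above, and the enum members mirror `syncode.parse_result.RemainderState`:

from enum import Enum, auto

class RemainderState(Enum):
    COMPLETE = auto()        # remainder is a finished terminal
    MAYBE_COMPLETE = auto()  # finished, but possibly a prefix ('de' vs 'def')
    INCOMPLETE = auto()      # code ends mid-token or mid-rule

def classify_remainder(lexing_incomplete: bool, parse_incomplete: bool,
                       trailing_ignored: bool) -> RemainderState:
    if lexing_incomplete or parse_incomplete:
        # The tail of the partial code is still being formed
        return RemainderState.INCOMPLETE
    if trailing_ignored:
        # The lexer consumed everything; the parser ignored the tail
        return RemainderState.COMPLETE
    # The last token parsed, but may still grow into a longer terminal
    return RemainderState.MAYBE_COMPLETE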
2 changes: 1 addition & 1 deletion tests/test_grammar_go.py
@@ -7,7 +7,7 @@
from syncode.parse_result import AcceptSequence, RemainderState

go_grammar = Grammar('go')
- inc_parser = create_parser(go_grammar)
+ inc_parser = create_parser(go_grammar, ignore_whitespace=True)

class TestGoParser(unittest.TestCase):
