Skip to content

Commit

Permalink
Merge pull request #139 from uiuc-focal-lab/instruct
Browse files Browse the repository at this point in the history
Add Go and Python example in the builtin grammar notebook
  • Loading branch information
shubhamugare authored Dec 25, 2024
2 parents 12a5e28 + 2c542d2 commit fb0dc52
Showing 1 changed file with 238 additions and 13 deletions.
251 changes: 238 additions & 13 deletions notebooks/tests/builtin_grammar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"from syncode import SyncodeLogitsProcessor\n",
Expand Down Expand Up @@ -75,8 +66,6 @@
}
],
"source": [
"# grammar_str = \"python\"\n",
"# grammar_str = \"go\"\n",
"grammar_str = \"java\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
Expand Down Expand Up @@ -105,6 +94,242 @@
"output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Write a python function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
" \n",
"\n",
"--------------------------------------------------\n",
"Parsing failed! Falling back to unconstrained decoding.\n",
"Exception: Unexpected token Token('NAME', 'simple') at line 3, column 11.\n",
"Expected one of: \n",
"\t* __ANON_9\n",
"\t* __ANON_21\n",
"\t* AMPERSAND\n",
"\t* RPAR\n",
"\t* __ANON_4\n",
"\t* LESSTHAN\n",
"\t* IF\n",
"\t* STAR\n",
"\t* RSQB\n",
"\t* __ANON_5\n",
"\t* __ANON_17\n",
"\t* LSQB\n",
"\t* SLASH\n",
"\t* MINUS\n",
"\t* VBAR\n",
"\t* _NL\n",
"\t* FROM\n",
"\t* __ANON_20\n",
"\t* EQUAL\n",
"\t* __ANON_22\n",
"\t* __ANON_13\n",
"\t* OR\n",
"\t* SEMICOLON\n",
"\t* PLUS\n",
"\t* LPAR\n",
"\t* CIRCUMFLEX\n",
"\t* FOR\n",
"\t* __ANON_2\n",
"\t* NOT\n",
"\t* AT\n",
"\t* __ANON_10\n",
"\t* COMMA\n",
"\t* __ANON_18\n",
"\t* COLON\n",
"\t* MORETHAN\n",
"\t* AS\n",
"\t* __ANON_6\n",
"\t* ELSE\n",
"\t* __ANON_16\n",
"\t* __ANON_11\n",
"\t* DOT\n",
"\t* IN\n",
"\t* __ANON_7\n",
"\t* ASYNC\n",
"\t* IS\n",
"\t* RBRACE\n",
"\t* __ANON_8\n",
"\t* __ANON_3\n",
"\t* AND\n",
"\t* __ANON_15\n",
"\t* __ANON_19\n",
"\t* __ANON_14\n",
"\t* PERCENT\n",
"\t* __ANON_12\n",
"\t* __ANON_1\n",
"\n",
"Partial code: ### Printing 'Hello World' in Reverse\n",
"\n",
"Here is a simple Python\n",
"Parsed lexical tokens: [Token('_NL', \"### Printing 'Hello World' in Reverse\\n\\n\"), Token('NAME', 'Here'), Token('IS', 'is'), Token('NAME', 'a'), Token('NAME', 'simple')]\n",
"--------------------------------------------------\n",
"[OUTPUT] ### Printing 'Hello World' in Reverse\n",
"\n",
"Here is a simple Python function that prints 'Hello World' in reverse:\n",
"\n",
"```python\n",
"def print_hello_world_reverse():\n",
" \"\"\"\n",
" Prints 'Hello World' in reverse.\n",
" \"\"\"\n",
" print(\"Hello World\")\n",
"\n",
"# Example usage:\n",
"print_hello_world_reverse()\n",
"```\n",
"\n",
"When you run this code, it will output:\n",
"```\n",
"olleH dlroW\n",
"```\n",
"\n",
"Alternatively, you can also use slicing to reverse the string:\n",
"\n",
"```python\n",
"def print_hello_world_reverse():\n",
" \"\"\"\n",
" Prints 'Hello World' in reverse.\n",
" \"\"\"\n",
" print(\" \".join([\"H\", \"e\", \"l\", \"l\", \"o\", \" \", \"W\", \"o\", \"r\", \"l\", \"d\"])\n",
"\n",
"# Example usage:\n",
"print_hello_world_reverse()\n",
"```\n",
"\n",
"This will output:\n",
"```\n",
"olleH dlroW\n",
"```\n",
"\n",
"Note: The `join()` method is used to concatenate the elements of a list into a single string, which is then printed.\n"
]
}
],
"source": [
"# Grammar-constrained generation demo using SynCode's builtin Python grammar.\n",
"grammar_str = \"python\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
"syncode_logits_processor = SyncodeLogitsProcessor(\n",
"    grammar=grammar,\n",
"    tokenizer=tokenizer,\n",
"    parse_output_only=True,\n",
")\n",
"\n",
"# Wrap the request in the model's chat template; keep it as text so it can be printed\n",
"# and handed to the logits processor before tokenization.\n",
"messages = [\n",
"    {\"role\": \"user\", \"content\": f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"}\n",
"]\n",
"prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
"\n",
"syncode_logits_processor.reset(prompt)\n",
"\n",
"# The templated prompt already contains the special tokens (note the leading\n",
"# <|begin_of_text|> in the printed prompt), so the tokenizer must not prepend\n",
"# a second BOS token when encoding it.\n",
"inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).input_ids.to(device)\n",
"attention_mask = torch.ones_like(inputs)\n",
"\n",
"output = model.generate(\n",
"    inputs,\n",
"    attention_mask=attention_mask,\n",
"    max_length=512,  # total budget: prompt tokens plus generated tokens\n",
"    num_return_sequences=1,\n",
"    pad_token_id=tokenizer.eos_token_id,\n",
"    logits_processor=[syncode_logits_processor],\n",
")\n",
"\n",
"# Drop the prompt tokens and decode only the newly generated continuation.\n",
"prompt_len = inputs.shape[1]\n",
"output_str = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 26 Jul 2024\n",
"\n",
"<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Write a go function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
" \n",
"\n",
"[OUTPUT] // \n",
"\n",
"package main\n",
"\n",
"import (\n",
" \"fmt\" // Import the fmt package\n",
")\n",
"\n",
"// Function to print 'hello world' in reverse\n",
"func printHelloWorld() {\n",
" // Declare a variable to hold the string 'hello world'\n",
" var s string = \"hello world\" // Define the string\n",
" // Use string reverse() to reverse the string\n",
" var reversed string = strings.Reverses(s) // Reverse the string\n",
" // Print the reversed string\n",
" fmt.Println(reversed) // Print the reversed string\n",
"}\n",
"\n",
"func main() {\n",
" // Call the function to print 'hello world' in reverse\n",
" printHelloWorld() // Call the function\n",
"} \n",
"\n",
"// Note: The string reverse() function in Go returns a string slice, not a string.\n",
"// If you want to convert the string slice to a string, you can use the string slice's string() method.\n",
"// Here, we are using the Reverses() function which returns a string slice.\n"
]
}
],
"source": [
"# Grammar-constrained generation demo using SynCode's builtin Go grammar.\n",
"grammar_str = \"go\"\n",
"\n",
"grammar = Grammar(grammar_str)\n",
"syncode_logits_processor = SyncodeLogitsProcessor(\n",
"    grammar=grammar,\n",
"    tokenizer=tokenizer,\n",
"    parse_output_only=True,\n",
")\n",
"\n",
"# Wrap the request in the model's chat template; keep it as text so it can be printed\n",
"# and handed to the logits processor before tokenization.\n",
"messages = [\n",
"    {\"role\": \"user\", \"content\": f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"}\n",
"]\n",
"prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"print(\"[PROMPT]\", prompt, \"\\n\")\n",
"\n",
"syncode_logits_processor.reset(prompt)\n",
"\n",
"# The templated prompt already contains the special tokens (note the leading\n",
"# <|begin_of_text|> in the printed prompt), so the tokenizer must not prepend\n",
"# a second BOS token when encoding it.\n",
"inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).input_ids.to(device)\n",
"attention_mask = torch.ones_like(inputs)\n",
"\n",
"output = model.generate(\n",
"    inputs,\n",
"    attention_mask=attention_mask,\n",
"    max_length=512,  # total budget: prompt tokens plus generated tokens\n",
"    num_return_sequences=1,\n",
"    pad_token_id=tokenizer.eos_token_id,\n",
"    logits_processor=[syncode_logits_processor],\n",
")\n",
"\n",
"# Drop the prompt tokens and decode only the newly generated continuation.\n",
"prompt_len = inputs.shape[1]\n",
"output_str = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)\n",
"print(\"[OUTPUT]\", output_str)"
]
}
],
"metadata": {
Expand Down

0 comments on commit fb0dc52

Please sign in to comment.