diff --git a/notebooks/tests/builtin_grammar.ipynb b/notebooks/tests/builtin_grammar.ipynb index 0231f7f..10fe62b 100644 --- a/notebooks/tests/builtin_grammar.ipynb +++ b/notebooks/tests/builtin_grammar.ipynb @@ -2,18 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "from syncode import SyncodeLogitsProcessor\n", @@ -75,8 +66,6 @@ } ], "source": [ - "# grammar_str = \"python\"\n", - "# grammar_str = \"go\"\n", "grammar_str = \"java\"\n", "\n", "grammar = Grammar(grammar_str)\n", @@ -105,6 +94,242 @@ "output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n", "print(\"[OUTPUT]\", output_str)" ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", + "\n", + "Cutting Knowledge Date: December 2023\n", + "Today Date: 26 Jul 2024\n", + "\n", + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n", + "\n", + "Write a python function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", + "\n", + " \n", + "\n", + "--------------------------------------------------\n", + "Parsing failed! Falling back to unconstrained decoding.\n", + "Exception: Unexpected token Token('NAME', 'simple') at line 3, column 11.\n", + "Expected one of: \n", + "\t* __ANON_9\n", + "\t* __ANON_21\n", + "\t* AMPERSAND\n", + "\t* RPAR\n", + "\t* __ANON_4\n", + "\t* LESSTHAN\n", + "\t* IF\n", + "\t* STAR\n", + "\t* RSQB\n", + "\t* __ANON_5\n", + "\t* __ANON_17\n", + "\t* LSQB\n", + "\t* SLASH\n", + "\t* MINUS\n", + "\t* VBAR\n", + "\t* _NL\n", + "\t* FROM\n", + "\t* __ANON_20\n", + "\t* EQUAL\n", + "\t* __ANON_22\n", + "\t* __ANON_13\n", + "\t* OR\n", + "\t* SEMICOLON\n", + "\t* PLUS\n", + "\t* LPAR\n", + "\t* CIRCUMFLEX\n", + "\t* FOR\n", + "\t* __ANON_2\n", + "\t* NOT\n", + "\t* AT\n", + "\t* __ANON_10\n", + "\t* COMMA\n", + "\t* __ANON_18\n", + "\t* COLON\n", + "\t* MORETHAN\n", + "\t* AS\n", + "\t* __ANON_6\n", + "\t* ELSE\n", + "\t* __ANON_16\n", + "\t* __ANON_11\n", + "\t* DOT\n", + "\t* IN\n", + "\t* __ANON_7\n", + "\t* ASYNC\n", + "\t* IS\n", + "\t* RBRACE\n", + "\t* __ANON_8\n", + "\t* __ANON_3\n", + "\t* AND\n", + "\t* __ANON_15\n", + "\t* __ANON_19\n", + "\t* __ANON_14\n", + "\t* PERCENT\n", + "\t* __ANON_12\n", + "\t* __ANON_1\n", + "\n", + "Partial code: ### Printing 'Hello World' in Reverse\n", + "\n", + "Here is a simple Python\n", + "Parsed lexical tokens: [Token('_NL', \"### Printing 'Hello World' in Reverse\\n\\n\"), Token('NAME', 'Here'), Token('IS', 'is'), Token('NAME', 'a'), Token('NAME', 'simple')]\n", + "--------------------------------------------------\n", + "[OUTPUT] ### Printing 'Hello World' in Reverse\n", + "\n", + "Here is a simple Python function that prints 'Hello World' in reverse:\n", + "\n", + "```python\n", + "def print_hello_world_reverse():\n", + " \"\"\"\n", + " Prints 'Hello World' in reverse.\n", + " \"\"\"\n", + " print(\"Hello World\")\n", + "\n", + "# Example usage:\n", + "print_hello_world_reverse()\n", + "```\n", + "\n", + "When you run this code, it will output:\n", + "```\n", + "olleH dlroW\n", + "```\n", + "\n", + "Alternatively, you can also use slicing to reverse the string:\n", + "\n", + "```python\n", + "def print_hello_world_reverse():\n", + " \"\"\"\n", + " Prints 'Hello World' in reverse.\n", + " \"\"\"\n", + " print(\" \".join([\"H\", \"e\", \"l\", \"l\", \"o\", \" \", \"W\", \"o\", \"r\", \"l\", \"d\"])\n", + "\n", + "# Example usage:\n", + "print_hello_world_reverse()\n", + "```\n", + "\n", + "This will output:\n", + "```\n", + "olleH dlroW\n", + "```\n", + "\n", + "Note: The `join()` method is used to concatenate the elements of a list into a single string, which is then printed.\n" + ] + } + ], + "source": [ + "grammar_str = \"python\"\n", + "\n", + "grammar = Grammar(grammar_str)\n", + "syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n", + "\n", + "prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n", + "messages = [{\"role\": \"user\", \"content\": prompt}]\n", + "prompt = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + " )\n", + "print(\"[PROMPT]\", prompt, \"\\n\")\n", + "\n", + "syncode_logits_processor.reset(prompt)\n", + "\n", + "inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n", + "\n", + "attention_mask = torch.ones_like(inputs)\n", + "output = model.generate(\n", + " inputs,\n", + " attention_mask=attention_mask,\n", + " max_length=512, \n", + " num_return_sequences=1, \n", + " pad_token_id=tokenizer.eos_token_id, \n", + " logits_processor=[syncode_logits_processor]\n", + " )\n", + "output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n", + "print(\"[OUTPUT]\", output_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", + "\n", + "Cutting Knowledge Date: December 2023\n", + "Today Date: 26 Jul 2024\n", + "\n", + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n", + "\n", + "Write a go function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", + "\n", + " \n", + "\n", + "[OUTPUT] // \n", + "\n", + "package main\n", + "\n", + "import (\n", + " \"fmt\" // Import the fmt package\n", + ")\n", + "\n", + "// Function to print 'hello world' in reverse\n", + "func printHelloWorld() {\n", + " // Declare a variable to hold the string 'hello world'\n", + " var s string = \"hello world\" // Define the string\n", + " // Use string reverse() to reverse the string\n", + " var reversed string = strings.Reverses(s) // Reverse the string\n", + " // Print the reversed string\n", + " fmt.Println(reversed) // Print the reversed string\n", + "}\n", + "\n", + "func main() {\n", + " // Call the function to print 'hello world' in reverse\n", + " printHelloWorld() // Call the function\n", + "} \n", + "\n", + "// Note: The string reverse() function in Go returns a string slice, not a string.\n", + "// If you want to convert the string slice to a string, you can use the string slice's string() method.\n", + "// Here, we are using the Reverses() function which returns a string slice.\n" + ] + } + ], + "source": [ + "grammar_str = \"go\"\n", + "\n", + "grammar = Grammar(grammar_str)\n", + "syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n", + "\n", + "prompt = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n", + "messages = [{\"role\": \"user\", \"content\": prompt}]\n", + "prompt = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + " )\n", + "print(\"[PROMPT]\", prompt, \"\\n\")\n", + "\n", + "syncode_logits_processor.reset(prompt)\n", + "\n", + "inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(device)\n", + "\n", + "attention_mask = torch.ones_like(inputs)\n", + "output = model.generate(\n", + " inputs,\n", + " attention_mask=attention_mask,\n", + " max_length=512, \n", + " num_return_sequences=1, \n", + " pad_token_id=tokenizer.eos_token_id, \n", + " logits_processor=[syncode_logits_processor]\n", + " )\n", + "output_str = tokenizer.decode(output[0][len(inputs[0]):], skip_special_tokens=True)\n", + "print(\"[OUTPUT]\", output_str)" + ] } ], "metadata": {