Commit: example code to use the finetuned model
Showing 2 changed files with 217 additions and 1 deletion.
@@ -0,0 +1,216 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d",
"metadata": {},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "f3f83194-82b9-4478-9550-5ad793467bd0",
"metadata": {},
"source": [
"# Load And Use Finetuned Model"
]
},
{
"cell_type": "markdown",
"id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e",
"metadata": {},
"source": [
"This notebook contains minimal code for loading the instruction-finetuned model that was created and saved in chapter 7 via [ch07.ipynb](ch07.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tiktoken version: 0.7.0\n",
"torch version: 2.3.1\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\n",
"    \"tiktoken\",  # Tokenizer\n",
"    \"torch\",     # Deep learning library\n",
"]\n",
"for p in pkgs:\n",
"    print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"finetuned_model_path = Path(\"gpt2-medium355M-sft.pth\")\n",
"if not finetuned_model_path.exists():\n",
"    print(\n",
"        f\"Could not find '{finetuned_model_path}'.\\n\"\n",
"        \"Run the `ch07.ipynb` notebook to finetune and save the finetuned model.\"\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fb02584a-5e31-45d5-8377-794876907bc6",
"metadata": {},
"outputs": [],
"source": [
"from previous_chapters import GPTModel\n",
"\n",
"\n",
"BASE_CONFIG = {\n",
"    \"vocab_size\": 50257,     # Vocabulary size\n",
"    \"context_length\": 1024,  # Context length\n",
"    \"drop_rate\": 0.0,        # Dropout rate\n",
"    \"qkv_bias\": True         # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
"    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
"    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
"    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
"    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"CHOOSE_MODEL = \"gpt2-medium (355M)\"\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
"\n",
"model = GPTModel(BASE_CONFIG)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"model.load_state_dict(torch.load(\"gpt2-medium355M-sft.pth\", map_location=torch.device(\"cpu\")))\n",
"model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5",
"metadata": {},
"outputs": [],
"source": [
"import tiktoken\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"\"\"Below is an instruction that describes a task. Write a response \n",
"that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Convert the active sentence to passive: 'The chef cooks the meal every day.'\n",
"\"\"\""
]
},
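{
"cell_type": "markdown",
"id": "added-example-prompt-helper-md",
"metadata": {},
"source": [
"Optional, illustrative extra (not from ch07.ipynb): a small, hypothetical `format_prompt` helper that rebuilds the prompt template from the previous cell for arbitrary instructions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-example-prompt-helper-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper (illustration only): rebuilds the prompt template shown above\n",
"def format_prompt(instruction):\n",
"    return (\n",
"        \"Below is an instruction that describes a task. Write a response \\n\"\n",
"        \"that appropriately completes the request.\\n\"\n",
"        \"\\n\"\n",
"        f\"### Instruction:\\n{instruction}\\n\"\n",
"    )\n",
"\n",
"# Should reproduce the `prompt` string defined in the previous cell\n",
"assert format_prompt(\n",
"    \"Convert the active sentence to passive: 'The chef cooks the meal every day.'\"\n",
") == prompt"
]
},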
{
"cell_type": "code",
"execution_count": 7,
"id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The meal is cooked every day by the chef.\n"
]
}
],
"source": [
"from previous_chapters import (\n",
"    generate,\n",
"    text_to_token_ids,\n",
"    token_ids_to_text\n",
")\n",
"\n",
"def extract_response(response_text, input_text):\n",
"    return response_text[len(input_text):].replace(\"### Response:\", \"\").strip()\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"token_ids = generate(\n",
"    model=model,\n",
"    idx=text_to_token_ids(prompt, tokenizer),\n",
"    max_new_tokens=35,\n",
"    context_size=BASE_CONFIG[\"context_length\"],\n",
"    eos_id=50256\n",
")\n",
"\n",
"response = token_ids_to_text(token_ids, tokenizer)\n",
"response = extract_response(response, prompt)\n",
"print(response)"
]
},
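{
"cell_type": "markdown",
"id": "added-example-device-md",
"metadata": {},
"source": [
"Optional, illustrative extra (not from ch07.ipynb): a minimal sketch of running the same generation on a GPU when one is available, assuming the `generate` function from `previous_chapters` keeps all tensors on the device of the input `idx`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-example-device-code",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (assumption: `generate` keeps tensors on the device of `idx`)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)\n",
"\n",
"token_ids = generate(\n",
"    model=model,\n",
"    idx=text_to_token_ids(prompt, tokenizer).to(device),\n",
"    max_new_tokens=35,\n",
"    context_size=BASE_CONFIG[\"context_length\"],\n",
"    eos_id=50256\n",
")\n",
"\n",
"print(extract_response(token_ids_to_text(token_ids, tokenizer), prompt))"
]
}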
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
} |