Commit 9d9eb46 (parent b8e5780): 1 changed file with 254 additions and 0 deletions.
@@ -0,0 +1,254 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "5a1b6d40-c917-4f30-adf3-c79a50cbc1be",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!python -m pip install --upgrade pip\n",
"!pip install -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4552c1ab-b20f-48fc-aa4d-3bf1576261ec",
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"import json\n",
"\n",
"import torch\n",
"import math\n",
"\n",
"from transformers import AutoTokenizer\n",
"from datasets import load_dataset\n",
"\n",
"from trainer.SFTTrainer import train\n",
"from model.args import MOEModelArgs\n",
"from model.KANamav5 import KANamav5\n",
"\n",
"from utils import load_model, quick_inference\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "77e0b467-149e-4247-bcf6-792ee89fba9f",
"metadata": {},
"outputs": [],
"source": [
"def lr_lambda(current_step: int, max_steps: int=50000, warmup_steps: int=40, lr_scheduler_type: str=\"cosine\"):\n",
"    if current_step < warmup_steps:\n",
"        return current_step / warmup_steps\n",
"\n",
"    annealing_steps = max_steps - warmup_steps\n",
"\n",
"    if annealing_steps <= 0:\n",
"        annealing_steps = 1\n",
"\n",
"    progress = (current_step - warmup_steps) / annealing_steps\n",
"    if lr_scheduler_type == \"cosine\":\n",
"        new_learning_rate = 0.5 * (1.0 + math.cos(math.pi * progress))\n",
"    elif lr_scheduler_type == \"sinus\":\n",
"        new_learning_rate = 0.5 * (1.0 + math.sin(math.pi * progress))\n",
"    else:\n",
"        new_learning_rate = 1.0\n",
"    return new_learning_rate"
]
},
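{
"cell_type": "code",
"execution_count": null,
"id": "lr-schedule-demo",
"metadata": {},
"outputs": [],
"source": [
"# Quick illustration (added for this write-up; not part of the original\n",
"# notebook): sample the schedule multiplier at a few steps to show the\n",
"# linear warmup followed by cosine decay (~1.0 right after warmup,\n",
"# ~0.0 at max_steps).\n",
"for step in [0, 20, 40, 10000, 25000, 50000]:\n",
"    print(step, round(lr_lambda(step), 4))"
]
},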
{
"cell_type": "code",
"execution_count": 6,
"id": "2a900a88-b089-4089-beaa-474641957678",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[LOADING TOKENIZER]\n"
]
}
],
"source": [
"print(\"[LOADING TOKENIZER]\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B\")\n",
"if tokenizer.pad_token is None:\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"\n",
"# len(tokenizer) counts added special tokens (e.g. the pad/eos token), which\n",
"# tokenizer.vocab_size can exclude; using it keeps pad_token_id in range.\n",
"MOEModelArgs.vocab_size = len(tokenizer)\n",
"MOEModelArgs.pad_id = tokenizer.pad_token_id\n",
"MOEModelArgs.num_experts_per_tok = 4\n",
"MOEModelArgs.max_batch_size = 100\n",
"MOEModelArgs.max_seq_len = 128\n",
"MOEModelArgs.num_experts = 14\n",
"MOEModelArgs.n_layers = 18\n",
"MOEModelArgs.dim = 256\n",
"MOEModelArgs.n_heads = 12  # note: 256 is not divisible by 12; many attention implementations expect dim % n_heads == 0\n",
"MOEModelArgs.n_kv_heads = 6\n",
"# MOEModelArgs.use_kan = False\n",
"# MOEModelArgs.use_softmax_temp_proj = False"
]
},
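{
"cell_type": "code",
"execution_count": null,
"id": "config-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (added; not part of the original notebook):\n",
"# print the configured sizes and the implied per-head width. Whether\n",
"# KANamav5 requires an integer head width is an assumption here.\n",
"print(f\"vocab_size={MOEModelArgs.vocab_size}, dim={MOEModelArgs.dim}\")\n",
"print(f\"n_heads={MOEModelArgs.n_heads}, width per head={MOEModelArgs.dim / MOEModelArgs.n_heads:.2f}\")"
]
},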
{
"cell_type": "code",
"execution_count": null,
"id": "8d4fab82-866a-426e-91dc-5d424e5dc491",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e09ca08e820741309f9c0792a4a7c96b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "946050163a5d4e569072ef57ff221ae9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files:   0%|          | 0/250 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Processing dataset:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Tokenizing dataset: 1103585it [1:02:45, 329.69it/s]"
]
}
],
"source": [
"# Load the FineWeb sample as a streaming dataset; swap in another subset as needed:\n",
"# dataset = load_dataset(\"HuggingFaceFW/fineweb\", name=\"CC-MAIN-2024-18\", split=\"train\", streaming=True)\n",
"#   BEE-spoke-data/fineweb-1000_64k\n",
"#   BEE-spoke-data/fineweb-100_128k\n",
"#   pszemraj/fineweb-1k_long\n",
"\n",
"dataset = load_dataset(\"BEE-spoke-data/fineweb-1000_64k\", split=\"train\", streaming=True)\n",
"\n",
"# List to hold all tokenized sequences\n",
"tokenized_data = []\n",
"\n",
"# Tokenize every document in the stream\n",
"print(\"\\nProcessing dataset:\")\n",
"for example in tqdm(dataset, desc=\"Tokenizing dataset\"):\n",
"    # Tokenize the current entry without truncation or padding\n",
"    tokens = tokenizer(example['text'], truncation=False, padding=False)['input_ids']\n",
"    tokenized_data.append(torch.tensor(tokens))\n",
"\n",
"# Concatenate all tokenized sequences into one long token stream\n",
"data = torch.cat(tokenized_data, dim=0).unsqueeze(0)  # unsqueeze adds a batch dimension\n",
"\n",
"# 90/10 train/val split along the token dimension (dim 1, since sequences\n",
"# are concatenated along it); note this keeps the full stream in device memory\n",
"n = int(0.9 * data.size(1))\n",
"train_data = data[:, :n].to(device)\n",
"val_data = data[:, n:].to(device)"
]
},
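{
"cell_type": "code",
"execution_count": null,
"id": "get-batch-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (added; not part of the original notebook): one way to\n",
"# sample fixed-length training windows from the concatenated token stream.\n",
"# Whether trainer.SFTTrainer.train batches the data like this is an assumption.\n",
"def get_batch(split_data, batch_size, seq_len):\n",
"    # split_data has shape (1, total_tokens); draw random start offsets\n",
"    starts = torch.randint(0, split_data.size(1) - seq_len - 1, (batch_size,))\n",
"    x = torch.stack([split_data[0, s:s + seq_len] for s in starts.tolist()])\n",
"    y = torch.stack([split_data[0, s + 1:s + seq_len + 1] for s in starts.tolist()])  # next-token targets\n",
"    return x, y\n",
"\n",
"xb, yb = get_batch(train_data, batch_size=4, seq_len=MOEModelArgs.max_seq_len)\n",
"print(xb.shape, yb.shape)"
]
},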
{
"cell_type": "code",
"execution_count": null,
"id": "76e21153-a15c-42b7-bfd2-c9d7c1e69f29",
"metadata": {},
"outputs": [],
"source": [
"print(\"\\n[LOADING MODEL]\\n\")\n",
"model = KANamav5(MOEModelArgs, device=device)"
]
},
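{
"cell_type": "code",
"execution_count": null,
"id": "param-count",
"metadata": {},
"outputs": [],
"source": [
"# Quick parameter count for the freshly initialized model (added for\n",
"# illustration; not part of the original notebook).\n",
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"total parameters: {total_params:,}\")"
]
},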
{
"cell_type": "code",
"execution_count": null,
"id": "507cb062-ed03-4df7-856e-9c23c1932e1c",
"metadata": {},
"outputs": [],
"source": [
"# Starting sequence (as tokens)\n",
"initial_text = \"Once upon a time\"\n",
"initial_tokens = tokenizer(initial_text, return_tensors=\"pt\").input_ids.to(device)\n",
"\n",
"# Perform inference\n",
"generated_tokens, generated_text = quick_inference(model, initial_tokens, max_new_tokens=50, tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dabca877-550d-4e93-a845-46dda978702e",
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)\n",
"# Pass max_steps explicitly so the schedule spans the whole run:\n",
"# lr_lambda defaults to max_steps=50000, but train() below runs 100000 steps.\n",
"scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: lr_lambda(step, max_steps=100000))\n",
"\n",
"print(\"\\n[TRAINING MODEL]\\n\")\n",
"new_model = train(\n",
"    model=model,\n",
"    optimizer=optimizer,\n",
"    train_data=train_data,\n",
"    val_data=val_data,\n",
"    scheduler=scheduler,\n",
"    save_model_name=\"KANama-medium\",\n",
"    max_steps=100000,\n",
"    loss_interval=5,\n",
"    eval_interval=50,\n",
"    device=device\n",
")"
]
},
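{
"cell_type": "code",
"execution_count": null,
"id": "post-training-inference",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative follow-up (added; not part of the original notebook): rerun the\n",
"# same quick inference on the trained model returned by train(), assuming\n",
"# new_model is a ready-to-use module like `model` above.\n",
"generated_tokens, generated_text = quick_inference(new_model, initial_tokens, max_new_tokens=50, tokenizer=tokenizer)\n",
"print(generated_text)"
]
}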
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}