def lr_lambda(current_step: int, max_steps: int=50000, warmup_steps: int=40, lr_scheduler_type: str="cosine") -> float:
    """Return the learning-rate multiplier for ``current_step``.

    The multiplier ramps linearly from 0 to 1 over the first
    ``warmup_steps`` steps, then follows the curve selected by
    ``lr_scheduler_type`` for the remaining ``max_steps - warmup_steps``
    steps.  Presumably this is handed to a ``LambdaLR``-style scheduler
    that calls it with the step index only -- confirm at the call site.

    Args:
        current_step: Index of the optimisation step being scheduled.
        max_steps: Step at which the annealing phase ends.
        warmup_steps: Number of linear warm-up steps.
        lr_scheduler_type: ``"cosine"`` for a half-cosine decay from 1 to 0,
            ``"sinus"`` for ``0.5 * (1 + sin(pi * progress))``; any other
            value yields a constant multiplier of 1.0 after warm-up.

    Returns:
        The factor by which the base learning rate is multiplied.
    """
    if current_step < warmup_steps:
        return current_step / warmup_steps

    annealing_steps = max_steps - warmup_steps

    # Guard against a zero/negative annealing span (max_steps <= warmup_steps).
    if annealing_steps <= 0:
        annealing_steps = 1

    # Cap progress at 1.0: without the cap, steps past ``max_steps`` push the
    # trigonometric argument beyond pi and the cosine multiplier starts
    # climbing back up instead of staying at its final value.
    progress = min((current_step - warmup_steps) / annealing_steps, 1.0)

    if lr_scheduler_type == "cosine":
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    if lr_scheduler_type == "sinus":
        return 0.5 * (1.0 + math.sin(math.pi * progress))
    return 1.0
"2a900a88-b089-4089-beaa-474641957678", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[LOADING TOKENIZER]\n" + ] + } + ], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B\")\n", + "print(\"[TOKENIZER LOADED]\")\n", + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "\n", + "MOEModelArgs.vocab_size = tokenizer.vocab_size\n", + "MOEModelArgs.pad_id = tokenizer.pad_token_id\n", + "MOEModelArgs.num_experts_per_tok = 4\n", + "MOEModelArgs.max_batch_size = 100\n", + "MOEModelArgs.max_seq_len = 128\n", + "MOEModelArgs.num_experts = 14\n", + "MOEModelArgs.n_layers = 18\n", + "MOEModelArgs.dim = 256\n", + "MOEModelArgs.n_heads = 12\n", + "MOEModelArgs.n_kv_heads = 6\n", + "# MOEModelArgs.use_kan = False\n", + "# MOEModelArgs.use_softmax_temp_proj = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d4fab82-866a-426e-91dc-5d424e5dc491", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e09ca08e820741309f9c0792a4a7c96b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/23781 [00:00