diff --git a/ppo_training_glm.ipynb b/ppo_training_glm.ipynb new file mode 100644 index 0000000..33ee273 --- /dev/null +++ b/ppo_training_glm.ipynb @@ -0,0 +1,445 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModel\n", + "from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## init reward model " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from reward_model import RewardModel\n", + "from transformers import AutoTokenizer\n", + "from peft import PeftModel\n", + "from torch.nn.utils import skip_init\n", + "import torch\n", + "\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True)\n", + "\n", + "\n", + "reward_model = RewardModel.from_pretrained(\"THUDM/chatglm-6b\", load_in_8bit=True, device_map='auto')\n", + "\n", + "## load score weight\n", + "\n", + "reward_model = PeftModel.from_pretrained(reward_model, './output/reward_model/', load_in_8bit=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "weight torch.Size([1, 4096]) tensor([0.0771, 0.0723, 0.1037, 0.1068, 0.0667], device='cuda:0',\n", + " dtype=torch.float16) tensor([0.0782, 0.0766, 0.0637, 0.0989, 0.1059], device='cuda:0',\n", + " dtype=torch.float16) tensor(0.0939, device='cuda:0', dtype=torch.float16) tensor(0.0825, device='cuda:0', dtype=torch.float16)\n" + ] + } + ], + "source": [ + "class CastOutputToHalf(torch.nn.Sequential):\n", + " def forward(self, x):\n", + " return super().forward(x).half()\n", + "\n", + "\n", + "reward_model.gradient_checkpointing_disable()\n", + "\n", + "reward_model.base_model.model.score.load_state_dict(torch.load(\"output/reward_model/score.bin\"))\n", + "\n", + "for k, p in reward_model.base_model.model.score.named_parameters():\n", + " print(k, p.shape, p[0, :5], p[0, -5:], p[0][:20].mean(), p[0][-20:].mean())\n", + "\n", + "# reward_model.score = CastOutputToHalf(reward_model.score)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "weight torch.Size([1, 4096]) tensor([0.0771, 0.0723, 0.1037, 0.1068, 0.0667], device='cuda:0',\n", + " dtype=torch.float16) tensor([0.0782, 0.0766, 0.0637, 0.0989, 0.1059], device='cuda:0',\n", + " dtype=torch.float16) tensor(0.0939, device='cuda:0', dtype=torch.float16) tensor(0.0825, device='cuda:0', dtype=torch.float16)\n" + ] + } + ], + "source": [ + "for k, p in reward_model.named_parameters():\n", + " p.requires_grad = False\n", + "\n", + "for k, p in reward_model.score.named_parameters():\n", + " print(k, p.shape, p[0, :5], p[0, -5:], p[0][:20].mean(), p[0][-20:].mean())\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## init actor" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.\n", + "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n", + "Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fb2b1ec062f54a38b5dc0768b2c4d98c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/8 [00:00