From f243c2186d1575d25fd2b62ef7f3e3d22fd03db6 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 4 Jan 2024 18:21:25 -0500
Subject: [PATCH] RL/DPO (#935)

* ipo-dpo trainer
* fix missing abstract method
* chatml template, grad checkpointing kwargs support
* fix steps calc for RL and add dataloader kwargs
* wip to fix dpo and start ppo
* more fixes
* refactor to generalize map fn
* fix dataset loop and handle argilla pref dataset
* set training args
* load reference model on separate gpu if more than one device
* no auto upload to hub for dpo, don't add lora adapters to ref model for dpo
* fixes for rl training
* support for ipo from yaml
* set dpo training args from the config, add tests
* chore: lint
* set sequence_len for model in test
* add RLHF docs
---
 docs/rlhf.md                          |  35 +++++++++
 requirements.txt                      |   2 +
 src/axolotl/cli/__init__.py           |  90 ++++++++++++++++++++++
 src/axolotl/cli/train.py              |   6 +-
 src/axolotl/core/trainer_builder.py   | 103 ++++++++++++++++++++++++++
 src/axolotl/core/trainers/__init__.py |   0
 src/axolotl/core/trainers/trl.py      |  66 +++++++++++++++++
 src/axolotl/train.py                  |   8 +-
 src/axolotl/utils/models.py           |  16 +++-
 src/axolotl/utils/trainer.py          |   9 ++-
 tests/core/test_trainer_builder.py    |  59 +++++++++++++++
 11 files changed, 388 insertions(+), 6 deletions(-)
 create mode 100644 docs/rlhf.md
 create mode 100644 src/axolotl/core/trainers/__init__.py
 create mode 100644 src/axolotl/core/trainers/trl.py
 create mode 100644 tests/core/test_trainer_builder.py

diff --git a/docs/rlhf.md b/docs/rlhf.md
new file mode 100644
index 0000000000..371a40dbf7
--- /dev/null
+++ b/docs/rlhf.md
@@ -0,0 +1,35 @@
# RLHF (Beta)

### Overview

Reinforcement Learning from Human Feedback (RLHF) is a method whereby a language model is optimized using data collected from human feedback. Various methods include, but are not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)

### RLHF using Axolotl

> [!IMPORTANT]
> This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.

The various RL training methods are implemented in trl and wrapped via axolotl.
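Each preference dataset is mapped into ChatML by one of the `*_apply_chatml` functions added to `src/axolotl/cli/__init__.py` later in this diff (selected by the dataset's `type` field in the examples below). Every mapping reduces a record to the three string columns trl's `DPOTrainer` expects: `prompt`, `chosen`, and `rejected`. As a rough illustration only (the question and answers here are invented), a mapped sample looks like this:

```python
# Illustrative sketch: the shape of one record after a *_apply_chatml map
# function has run. The contents are made up for the example.
mapped_sample = {
    "prompt": (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    "chosen": "The capital of France is Paris.<|im_end|>",
    "rejected": "The capital of France is London.<|im_end|>",
}
```

The `intel_apply_chatml`, `argilla_apply_chatml`, and other variants differ only in which source fields they read (`question` vs. `instruction`, flat strings vs. lists of chat messages, and so on).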
Below are various examples with how you can use various preference datasets to train models that use ChatML + +#### DPO +```yaml +rl: true +datasets: + - path: Intel/orca_dpo_pairs + split: train + type: intel_apply_chatml + - path: argilla/ultrafeedback-binarized-preferences + split: train + type: argilla_apply_chatml +``` + +#### IPO +```yaml +rl: ipo +``` diff --git a/requirements.txt b/requirements.txt index f4df0dd671..14f6633f7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,3 +37,5 @@ tensorboard s3fs gcsfs # adlfs + +trl @ git+https://github.com/huggingface/trl.git@main diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 85f6b358ac..0477ebebfb 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -2,6 +2,7 @@ import importlib import logging +import math import os import random import sys @@ -16,6 +17,7 @@ # add src to the pythonpath so we don't need to pip install this from accelerate.commands.config import config_args from art import text2art +from datasets import concatenate_datasets, load_dataset from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer @@ -325,6 +327,94 @@ def load_datasets( ) +def load_rl_datasets( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, # pylint: disable=unused-argument +) -> TrainDatasetMeta: + train_datasets: List[Any] = [] + for i, ds_cfg in enumerate(cfg.datasets): + train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"])) + # eval_dataset = load_dataset( + # cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"] + # ) + eval_dataset = None + + def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" + sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" + return sample + + def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + def apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + 
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + return sample + + for i, data_set in enumerate(train_datasets): + _type = cfg.datasets[i]["type"] + ds_type_fn = locals()[_type] + train_datasets[i] = data_set.map(ds_type_fn) + train_dataset = concatenate_datasets(train_datasets) + + # eval_dataset = eval_dataset.map(intel_apply_chatml) + + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) + + def check_accelerate_default_config(): if Path(config_args.default_yaml_config_file).exists(): LOG.warning( diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 81307b6b92..2248784dff 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -12,6 +12,7 @@ check_user_token, load_cfg, load_datasets, + load_rl_datasets, print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs @@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs): parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) - dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + if parsed_cfg.rl: + dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + else: + dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4ca2877d19..1ca36eb418 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -20,6 +20,7 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_utils import seed_worker +from trl import DPOTrainer from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils.callbacks import ( @@ -420,12 +421,21 @@ class TrainerBuilderBase(abc.ABC): _train_dataset = None _eval_dataset = None + _model_ref = None def __init__(self, cfg, model, tokenizer): self.cfg = cfg self.model = model self.tokenizer = tokenizer + @property + def model_ref(self): + return self._model_ref + + @model_ref.setter + def model_ref(self, model): + self._model_ref = model + @property def train_dataset(self): return self._train_dataset @@ -827,3 +837,96 @@ def build_collator(self, **kwargs): return_tensors="pt", **kwargs, ) + + +class HFDPOTrainerBuilder(TrainerBuilderBase): + """ + Trainer factory class for DPO Trainer + """ + + def get_callbacks(self): + callbacks = [] + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + return callbacks + + def build_training_arguments(self, total_num_steps): + training_args_kwargs = {} + for arg in [ + "adam_beta1", + "adam_beta2", + "adam_epsilon", + "dataloader_num_workers", + "dataloader_pin_memory", + ]: + if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: + training_args_kwargs[arg] = getattr(self.cfg, arg) + training_args = TrainingArguments( + per_device_train_batch_size=self.cfg.micro_batch_size, + 
max_steps=total_num_steps, + remove_unused_columns=False, + gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, + learning_rate=self.cfg.learning_rate, + evaluation_strategy="no", + # eval_steps=self.cfg.eval_steps, + save_strategy="steps", + save_steps=self.cfg.save_steps, + output_dir=self.cfg.output_dir, + warmup_steps=self.cfg.warmup_steps, + bf16=True, + gradient_checkpointing=self.cfg.gradient_checkpointing, + gradient_checkpointing_kwargs={"use_reentrant": False}, + logging_first_step=True, + logging_steps=1, + optim=self.cfg.optimizer, + save_total_limit=self.cfg.save_total_limit or 5, + **training_args_kwargs, + ) + + return training_args + + def build(self, total_num_steps): + training_args = self.build_training_arguments(total_num_steps) + dpo_trainer_kwargs = {} + if self.cfg.rl == "ipo": + dpo_trainer_kwargs["loss_type"] = "ipo" + if self.cfg.dpo_label_smoothing: + dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + + dpo_trainer = DPOTrainer( + self.model, + self.model_ref, + args=training_args, + beta=self.cfg.dpo_beta or 0.1, + train_dataset=self.train_dataset, + # eval_dataset=self.eval_dataset, + eval_dataset=None, + tokenizer=self.tokenizer, + max_length=self.cfg.sequence_len, + max_target_length=None, + max_prompt_length=self.cfg.sequence_len, + generate_during_eval=True, + **dpo_trainer_kwargs, + ) + + return dpo_trainer + + +class HFPPOTrainerBuilder(TrainerBuilderBase): + """ + HF Factory class for PPO Trainer + """ + + def get_callbacks(self): + callbacks = [] + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + return callbacks + + def build(self, total_num_steps): + # build PPOConfig + pass diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py new file mode 100644 index 0000000000..24c0b04123 --- /dev/null +++ b/src/axolotl/core/trainers/trl.py @@ -0,0 +1,66 @@ +""" +module for TRL PPO training +""" +import torch +from tqdm import tqdm +from trl import PPOTrainer + + +class TRLPPOTrainer(PPOTrainer): + """ + wrapper for ppo trainer to handle customizations + """ + + def train( + self, + reward_pipe, + resume_from_checkpoint=None, # pylint: disable=unused-argument + ): + generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.tokenizer.eos_token_id, + "max_new_tokens": 32, + } + sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": 16, + } + + for epoch, batch in tqdm( # pylint: disable=unused-variable + enumerate(self.dataloader) + ): + query_tensors = batch["input_ids"] + + # generate model response + response_tensors, ref_response_tensors = self.generate( + query_tensors, + return_prompt=False, + generate_ref_response=True, + **generation_kwargs + ) + batch["response"] = self.tokenizer.batch_decode(response_tensors) + batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] + ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs) + ref_rewards = [ + torch.tensor(output[1]["score"]) for output in 
ref_pipe_outputs + ] + batch["ref_rewards"] = ref_rewards + + # Run PPO step + stats = self.step(query_tensors, response_tensors, rewards) + self.log_stats( + stats, + batch, + rewards, + columns_to_log=["query", "response", "ref_response", "ref_rewards"], + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 4e5241e4c8..e0da112528 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -61,6 +61,12 @@ def train( msg += " and peft_config..." LOG.debug(msg) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) + model_ref = None + if cfg.rl: + # load the model again for model_ref/baseline + model_ref, _ = load_model( + cfg, tokenizer, inference=cli_args.inference, reference_model=True + ) safe_serialization = cfg.save_safetensors is True @@ -83,7 +89,7 @@ def train( freeze_parameters_except(model, cfg.unfrozen_parameters) trainer = setup_trainer( - cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps + cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps ) if hasattr(model, "config"): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fb2420108a..b30ffcad8c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -200,6 +200,7 @@ def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, inference: bool = False, + reference_model: bool = False, ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: """ Load a model for a given configuration and tokenizer. @@ -290,6 +291,15 @@ def load_model( model_kwargs["device_map"] = cfg.device_map model_kwargs["max_memory"] = cfg.max_memory model_kwargs["torch_dtype"] = cfg.torch_dtype + # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) if is_deepspeed_zero3_enabled(): del model_kwargs["device_map"] @@ -560,9 +570,11 @@ def load_model( if hasattr(module, "weight"): module.to(cfg.torch_dtype) - model, lora_config = load_adapter(model, cfg, cfg.adapter) + lora_config = None + if not reference_model or cfg.lora_model_dir: + model, lora_config = load_adapter(model, cfg, cfg.adapter) - if cfg.ddp and not load_in_8bit: + if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): model.to(f"cuda:{cfg.local_rank}") if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f046dd7be8..d975bb9a2d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -12,7 +12,7 @@ from datasets import set_caching_enabled from torch.utils.data import DataLoader, RandomSampler -from axolotl.core.trainer_builder import HFCausalTrainerBuilder +from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.samplers import MultipackBatchSampler @@ -280,7 +280,12 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) + if cfg.rl: + trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder.model_ref = model[1] + else: + 
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py new file mode 100644 index 0000000000..e8987ef452 --- /dev/null +++ b/tests/core/test_trainer_builder.py @@ -0,0 +1,59 @@ +""" +unit tests for axolotl.core.trainer_builder +""" +import pytest + +from axolotl.core.trainer_builder import HFDPOTrainerBuilder +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + + +@pytest.fixture(name="cfg") +def fixture_cfg(): + return DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "model_type": "AutoModelForCausalLM", + "tokenizer_type": "LlamaTokenizer", + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.00005, + "save_steps": 100, + "output_dir": "./model-out", + "warmup_steps": 10, + "gradient_checkpointing": False, + "optimizer": "adamw_torch", + "sequence_len": 2048, + "rl": True, + "adam_beta1": 0.998, + "adam_beta2": 0.9, + "adam_epsilon": 0.00001, + "dataloader_num_workers": 1, + "dataloader_pin_memory": True, + } + ) + + +@pytest.fixture(name="tokenizer") +def fixture_tokenizer(cfg): + return load_tokenizer(cfg) + + +@pytest.fixture(name="model") +def fixture_model(cfg, tokenizer): + return load_model(cfg, tokenizer) + + +class TestHFDPOTrainerBuilder: + """ + TestCase class for DPO trainer builder + """ + + def test_build_training_arguments(self, cfg, model, tokenizer): + builder = HFDPOTrainerBuilder(cfg, model, tokenizer) + training_arguments = builder.build_training_arguments(100) + assert training_arguments.adam_beta1 == 0.998 + assert training_arguments.adam_beta2 == 0.9 + assert training_arguments.adam_epsilon == 0.00001 + assert training_arguments.dataloader_num_workers == 1 + assert training_arguments.dataloader_pin_memory is True
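For context on how these pieces are exercised end to end, below is a hypothetical minimal DPO config assembled from the options this diff actually reads (the docs example, the test fixture, and the keys consumed by `HFDPOTrainerBuilder` and `load_rl_datasets`). It is an untested sketch, not a recommended recipe; values are illustrative.

```yaml
# Hypothetical config sketch -- values are illustrative, not tuned.
base_model: TinyLlama/TinyLlama-1.1B-Chat-v0.6
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer

rl: true              # enables the DPO code path (use `rl: ipo` for the IPO loss)
dpo_beta: 0.1         # optional; HFDPOTrainerBuilder falls back to 0.1

datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: intel_apply_chatml

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 1
num_epochs: 1
learning_rate: 5e-5
optimizer: adamw_torch
warmup_steps: 10
save_steps: 100
save_total_limit: 5
output_dir: ./dpo-out
gradient_checkpointing: true
```

Launched like any other axolotl run (e.g. `accelerate launch -m axolotl.cli.train dpo.yml`), the `rl` flag routes `do_cli` through `load_rl_datasets`, `train()` loads a second copy of the model as the frozen reference, and `setup_trainer` hands both to `HFDPOTrainerBuilder`, which builds the trl `DPOTrainer`. Note that the builder currently hardcodes `bf16=True` in `TrainingArguments`, so bf16-capable hardware is assumed.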