Skip to content

Commit

Permalink
RL/DPO (#935)
Browse files Browse the repository at this point in the history
* ipo-dpo trainer

* fix missing abstract method

* chatml template, grad checkpointing kwargs support

* fix steps calc for RL and add dataloader kwargs

* wip to fix dpo and start ppo

* more fixes

* refactor to generalize map fn

* fix dataset loop and handle argilla pref dataset

* set training args

* load reference model on seperate gpu if more than one device

* no auto upload to hub for dpo, don't add lora adapters to ref model for dpo

* fixes for rl training

* support for ipo from yaml

* set dpo training args from the config, add tests

* chore: lint

* set sequence_len for model in test

* add RLHF docs
  • Loading branch information
winglian committed Jan 4, 2024
1 parent 59b2d30 commit f243c21
Show file tree
Hide file tree
Showing 11 changed files with 388 additions and 6 deletions.
35 changes: 35 additions & 0 deletions docs/rlhf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# RLHF (Beta)

### Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)


### RLHF using Axolotl

[!IMPORTANT]
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.

The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML

#### DPO
```yaml
rl: true
datasets:
- path: Intel/orca_dpo_pairs
split: train
type: intel_apply_chatml
- path: argilla/ultrafeedback-binarized-preferences
split: train
type: argilla_apply_chatml
```
#### IPO
```yaml
rl: ipo
```
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,5 @@ tensorboard
s3fs
gcsfs
# adlfs

trl @ git+https://github.com/huggingface/trl.git@main
90 changes: 90 additions & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import importlib
import logging
import math
import os
import random
import sys
Expand All @@ -16,6 +17,7 @@
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
Expand Down Expand Up @@ -325,6 +327,94 @@ def load_datasets(
)


def load_rl_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_datasets: List[Any] = []
for i, ds_cfg in enumerate(cfg.datasets):
train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
# eval_dataset = load_dataset(
# cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
# )
eval_dataset = None

def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
return sample

def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample

def apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample

def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample

for i, data_set in enumerate(train_datasets):
_type = cfg.datasets[i]["type"]
ds_type_fn = locals()[_type]
train_datasets[i] = data_set.map(ds_type_fn)
train_dataset = concatenate_datasets(train_datasets)

# eval_dataset = eval_dataset.map(intel_apply_chatml)

total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)

return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)


def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
Expand Down
6 changes: 5 additions & 1 deletion src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
Expand All @@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cfg.rl:
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


Expand Down
103 changes: 103 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import seed_worker
from trl import DPOTrainer

from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
Expand Down Expand Up @@ -420,12 +421,21 @@ class TrainerBuilderBase(abc.ABC):

_train_dataset = None
_eval_dataset = None
_model_ref = None

def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer

@property
def model_ref(self):
return self._model_ref

@model_ref.setter
def model_ref(self, model):
self._model_ref = model

@property
def train_dataset(self):
return self._train_dataset
Expand Down Expand Up @@ -827,3 +837,96 @@ def build_collator(self, **kwargs):
return_tensors="pt",
**kwargs,
)


class HFDPOTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for DPO Trainer
"""

def get_callbacks(self):
callbacks = []
return callbacks

def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
return callbacks

def build_training_arguments(self, total_num_steps):
training_args_kwargs = {}
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"dataloader_num_workers",
"dataloader_pin_memory",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=total_num_steps,
remove_unused_columns=False,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
evaluation_strategy="no",
# eval_steps=self.cfg.eval_steps,
save_strategy="steps",
save_steps=self.cfg.save_steps,
output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps,
bf16=True,
gradient_checkpointing=self.cfg.gradient_checkpointing,
gradient_checkpointing_kwargs={"use_reentrant": False},
logging_first_step=True,
logging_steps=1,
optim=self.cfg.optimizer,
save_total_limit=self.cfg.save_total_limit or 5,
**training_args_kwargs,
)

return training_args

def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing

dpo_trainer = DPOTrainer(
self.model,
self.model_ref,
args=training_args,
beta=self.cfg.dpo_beta or 0.1,
train_dataset=self.train_dataset,
# eval_dataset=self.eval_dataset,
eval_dataset=None,
tokenizer=self.tokenizer,
max_length=self.cfg.sequence_len,
max_target_length=None,
max_prompt_length=self.cfg.sequence_len,
generate_during_eval=True,
**dpo_trainer_kwargs,
)

return dpo_trainer


class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""

def get_callbacks(self):
callbacks = []
return callbacks

def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
return callbacks

def build(self, total_num_steps):
# build PPOConfig
pass
Empty file.
66 changes: 66 additions & 0 deletions src/axolotl/core/trainers/trl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
module for TRL PPO training
"""
import torch
from tqdm import tqdm
from trl import PPOTrainer


class TRLPPOTrainer(PPOTrainer):
"""
wrapper for ppo trainer to handle customizations
"""

def train(
self,
reward_pipe,
resume_from_checkpoint=None, # pylint: disable=unused-argument
):
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": 32,
}
sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": 16,
}

for epoch, batch in tqdm( # pylint: disable=unused-variable
enumerate(self.dataloader)
):
query_tensors = batch["input_ids"]

# generate model response
response_tensors, ref_response_tensors = self.generate(
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)

# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = reward_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
ref_rewards = [
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
]
batch["ref_rewards"] = ref_rewards

# Run PPO step
stats = self.step(query_tensors, response_tensors, rewards)
self.log_stats(
stats,
batch,
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)
8 changes: 7 additions & 1 deletion src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def train(
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model_ref = None
if cfg.rl:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)

safe_serialization = cfg.save_safetensors is True

Expand All @@ -83,7 +89,7 @@ def train(
freeze_parameters_except(model, cfg.unfrozen_parameters)

trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
)

if hasattr(model, "config"):
Expand Down
Loading

0 comments on commit f243c21

Please sign in to comment.