diff --git a/docs/source/cpo_trainer.mdx b/docs/source/cpo_trainer.mdx
index 676825f698..6ebb3b84ca 100644
--- a/docs/source/cpo_trainer.mdx
+++ b/docs/source/cpo_trainer.mdx
@@ -2,101 +2,66 @@
[![](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo)
-Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high-level, CPO trains models to
-avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
+## Overview
-CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
+Contrastive Preference Optimization (CPO) was introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating translations that are adequate but not perfect in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains, such as chat.
-## SimPO
-The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the `CPOTrainer`. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0` in the `CPOConfig`.
+CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Second, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
-## CPO-SimPO
-We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO Github](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the CPOConfig.
+## Quick start
-## Expected dataset format
+This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the preference data from the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara-Preferences), which you can browse on the Hugging Face Hub.
-The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
-
-- `prompt`
-- `chosen`
-- `rejected`
-
-for example:
-
-```py
-cpo_dataset_dict = {
- "prompt": [
- "hello",
- "how are you",
- "What is your name?",
- "What is your name?",
- "Which is the best programming language?",
- "Which is the best programming language?",
- "Which is the best programming language?",
- ],
- "chosen": [
- "hi nice to meet you",
- "I am fine",
- "My name is Mary",
- "My name is Mary",
- "Python",
- "Python",
- "Java",
- ],
- "rejected": [
- "leave me alone",
- "I am not fine",
- "Whats it to you?",
- "I dont have a name",
- "Javascript",
- "C++",
- "C++",
- ],
-}
-```
-where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
+
-## Expected model format
-The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
+Below is the script to train the model:
-## Using the `CPOTrainer`
-For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above.
+```python
+# train_cpo.py
+from datasets import load_dataset
+from trl import CPOConfig, CPOTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer
-```py
-training_args = CPOConfig(
- beta=0.1,
-)
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+train_dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
-cpo_trainer = CPOTrainer(
- model,
- args=training_args,
- train_dataset=train_dataset,
- tokenizer=tokenizer,
-)
+training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10)
+trainer = CPOTrainer(model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset)
+trainer.train()
```
-After this one can then call:
-```py
-cpo_trainer.train()
-```
+Execute the script using the following command:
-## Loss functions
+```bash
+accelerate launch train_cpo.py
+```
-Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
+## Expected dataset format
-The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
+CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
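+
+For illustration, a minimal record in each format could look like the following (toy examples, not taken from an actual dataset):
+
+```python
+# Standard (explicit prompt) preference example
+standard_example = {
+    "prompt": "The sky is",
+    "chosen": " blue.",
+    "rejected": " green.",
+}
+
+# Conversational preference example; the chat template is applied automatically
+conversational_example = {
+    "prompt": [{"role": "user", "content": "What color is the sky?"}],
+    "chosen": [{"role": "assistant", "content": "It is blue."}],
+    "rejected": [{"role": "assistant", "content": "It is green."}],
+}
+```
+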
-The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).
+## Example script
-### For Mixture of Experts Models: Enabling the auxiliary loss
+We provide an example script to train a model using the CPO method. The script is available at [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py).
-MOEs are the most efficient if the load is about equally distributed between experts.
-To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
+To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
-This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
-To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
+```bash
+accelerate launch examples/scripts/cpo.py \
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+ --dataset_name trl-lib/ultrafeedback_binarized \
+ --num_train_epochs 1 \
+ --logging_steps 25 \
+ --output_dir Qwen2-0.5B-CPO
+```
-## Logging
+## Logged metrics
While training and evaluating we record the following reward metrics:
@@ -106,6 +71,34 @@ While training and evaluating we record the following reward metrics:
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
+## CPO variants
+
+### Simple Preference Optimization (SimPO)
+
+The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, set `loss_type="simpo"` and `cpo_alpha=0` in the [`CPOConfig`].
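+
+For example, a [`CPOConfig`] for SimPO-style training might look like this (the values shown are illustrative):
+
+```python
+from trl import CPOConfig
+
+# SimPO: length-normalized reward with a target margin and no BC regularization
+training_args = CPOConfig(
+    output_dir="Qwen2-0.5B-SimPO",
+    loss_type="simpo",
+    cpo_alpha=0.0,    # disables the behavior-cloning (NLL) term
+    simpo_gamma=0.5,  # target reward margin used by the SimPO loss
+)
+```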
+
+### CPO-SimPO
+
+We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more in the [CPO-SimPO GitHub repository](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, enable SimPO by setting `loss_type="simpo"` and use a non-zero `cpo_alpha` in the [`CPOConfig`].
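+
+A minimal sketch of a CPO-SimPO configuration, keeping a non-zero `cpo_alpha` so the behavior-cloning (NLL) term stays active alongside the SimPO loss (the value shown is illustrative):
+
+```python
+from trl import CPOConfig
+
+training_args = CPOConfig(
+    output_dir="Qwen2-0.5B-CPO-SimPO",
+    loss_type="simpo",
+    cpo_alpha=0.5,  # non-zero value keeps the BC regularizer
+)
+```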
+
+## Loss functions
+
+The CPO algorithm supports several loss functions, which can be selected with the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
+
+| `loss_type=` | Description |
+| --------------------- | ----------- |
+| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
+| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
+| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms, identify an issue with overfitting, and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, so the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike DPO, which only sums them). |
+
+### For Mixture of Experts Models: Enabling the auxiliary loss
+
+MoEs are most efficient when the load is distributed roughly equally between experts.
+To ensure that MoEs are trained similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
+
+This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
+To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
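+
+A minimal sketch of how this could look when loading a MoE model (the model name and coefficient are illustrative):
+
+```python
+from transformers import AutoModelForCausalLM
+
+# Return the router logits so the load-balancing auxiliary loss can be added to the final loss
+model = AutoModelForCausalLM.from_pretrained(
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    output_router_logits=True,
+    router_aux_loss_coef=0.001,  # weight of the auxiliary loss in the total loss
+)
+```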
+
## CPOTrainer
[[autodoc]] CPOTrainer
diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx
index 03444d7cf1..6b8c0fbdda 100644
--- a/docs/source/dataset_formats.mdx
+++ b/docs/source/dataset_formats.mdx
@@ -194,20 +194,20 @@ unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "
Choosing the right dataset format depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset formats supported by each TRL trainer.
-| Trainer | Expected dataset format |
-| ----------------------- | ---------------------------- |
-| [`BCOTrainer`] | Unpaired preference |
-| [`CPOTrainer`] | Preference (explicit prompt) |
-| [`DPOTrainer`] | Preference (explicit prompt) |
-| [`IterativeSFTTrainer`] | Unpaired preference |
-| [`KTOTrainer`] | Unpaired preference |
-| [`NashMDTrainer`] | Prompt-only |
-| [`OnlineDPOTrainer`] | Prompt-only |
-| [`ORPOTrainer`] | Preference (explicit prompt) |
-| [`PPOv2Trainer`] | Tokenized language modeling |
-| [`RewardTrainer`] | Preference (implicit prompt) |
-| [`SFTTrainer`] | Language modeling |
-| [`XPOTrainer`] | Prompt-only |
+| Trainer | Expected dataset format |
+| ----------------------- | ------------------------------------------------------- |
+| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
+| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
+| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
+| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
+| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) |
+| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
+| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
+| [`ORPOTrainer`] | [Preference (explicit prompt)](#preference) |
+| [`PPOv2Trainer`] | Tokenized language modeling |
+| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
+| [`SFTTrainer`] | [Language modeling](#language-modeling) |
+| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py
index 341ea67cac..96192cd472 100644
--- a/examples/scripts/cpo.py
+++ b/examples/scripts/cpo.py
@@ -54,7 +54,6 @@
from dataclasses import dataclass, field
-from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
@@ -65,7 +64,7 @@
@dataclass
class ScriptArguments:
dataset_name: str = field(
- default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
+ default="trl-lib/ultrafeedback_binarized",
metadata={"help": "The name of the dataset to use."},
)
@@ -93,16 +92,6 @@ class ScriptArguments:
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
- def process(row):
- row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
- row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
- return row
-
- # Compute that only on the main process for faster data processing.
- # see: https://github.com/huggingface/trl/pull/1255
- with PartialState().local_main_process_first():
- dataset = dataset.map(process, num_proc=training_args.dataset_num_proc)
-
################
# Training
################
diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py
index 6ffb7e13f6..15f39a630b 100644
--- a/tests/test_cpo_trainer.py
+++ b/tests/test_cpo_trainer.py
@@ -37,15 +37,16 @@ def setUp(self):
@parameterized.expand(
[
- ["gpt2", "sigmoid"],
- ["t5", "hinge"],
- ["gpt2", "ipo"],
- ["t5", "ipo"],
- ["gpt2", "simpo"],
- ["t5", "simpo"],
+ ["gpt2", "sigmoid", "standard_preference"],
+ ["t5", "hinge", "standard_implicit_prompt_preference"],
+ ["gpt2", "ipo", "conversational_preference"],
+ ["t5", "ipo", "conversational_implicit_prompt_preference"],
+ ["gpt2", "simpo", "standard_preference"],
+ ["t5", "simpo", "standard_implicit_prompt_preference"],
+ ["gpt2", "hinge", "conversational_preference"],
]
)
- def test_cpo_trainer(self, name, loss_type):
+ def test_cpo_trainer(self, name, loss_type, config_name):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = CPOConfig(
output_dir=tmp_dir,
@@ -61,7 +62,7 @@ def test_cpo_trainer(self, name, loss_type):
report_to="none",
)
- dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+ dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
if name == "gpt2":
model = self.model
@@ -93,7 +94,15 @@ def test_cpo_trainer(self, name, loss_type):
assert not torch.equal(param, new_param)
@require_peft
- def test_cpo_trainer_with_lora(self):
+ @parameterized.expand(
+ [
+ ("standard_preference",),
+ ("standard_implicit_prompt_preference",),
+ ("conversational_preference",),
+ ("conversational_implicit_prompt_preference",),
+ ]
+ )
+ def test_cpo_trainer_with_lora(self, config_name):
from peft import LoraConfig
lora_config = LoraConfig(
@@ -118,7 +127,7 @@ def test_cpo_trainer_with_lora(self):
report_to="none",
)
- dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+ dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
trainer = CPOTrainer(
model=self.model,
diff --git a/trl/trainer/cpo_config.py b/trl/trainer/cpo_config.py
index f61672de4e..ac45b203e5 100644
--- a/trl/trainer/cpo_config.py
+++ b/trl/trainer/cpo_config.py
@@ -27,6 +27,9 @@ class CPOConfig(TrainingArguments):
command line.
Parameters:
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
+ [`~transformers.TrainingArguments`].
max_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
@@ -74,6 +77,7 @@ class CPOConfig(TrainingArguments):
Number of processes to use for processing the dataset.
"""
+ learning_rate: float = 1e-6
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 738f563323..929e137062 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -42,6 +42,7 @@
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
+from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .cpo_config import CPOConfig
from .utils import (
DPODataCollatorWithPadding,
@@ -302,6 +303,17 @@ def make_inputs_require_grad(module, input, output):
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
+ # Extract the prompt if needed, and apply the chat template if needed
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
+ train_dataset = train_dataset.map(
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer}, num_proc=args.dataset_num_proc
+ )
+ if eval_dataset is not None:
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
+ eval_dataset = eval_dataset.map(
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer}, num_proc=args.dataset_num_proc
+ )
+
# tokenize the dataset
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
if eval_dataset is not None: