diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx
index 91c6ea69b1..635676b85e 100644
--- a/docs/source/kto_trainer.mdx
+++ b/docs/source/kto_trainer.mdx
@@ -7,6 +7,7 @@ For a full example have a look at [`examples/scripts/kto.py`].
Depending on how good your base model is, you may or may not need to do SFT before KTO.
This is different from standard RLHF and DPO, which always require SFT.
+You can also train with imbalanced data (more chosen than rejected examples, or vice-versa), but you will need to adjust hyperparameters accordingly (see below).
## Expected dataset format
@@ -51,7 +52,8 @@ kto_dataset_dict = {
```
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
-A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
+A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
+In theory, the dataset must contain at least one desirable and one undesirable completion; however, some people have had success running KTO on _only_ desirable or undesirable data (in the latter case, it is best to use a conservative learning rate).
## Expected model format
@@ -61,13 +63,17 @@ The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
-The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
+The `beta` refers to the hyperparameter that controls how quickly the loss saturates, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
-It is strongly recommended you use a learning rate between `5e-7` and `5e-6` with an effective batch size between `8` and `32`, for both LoRA and full finetuning. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, using smaller batch sizes and/or more training epochs will give you better results.
+Every choice of `beta` has a maximum learning rate it will tolerate before learning degenerates. For the default `beta = 0.1', this learning rate is `1e-6` for most models. The lower the beta is, the lower your learning rate should be. In general, we strongly recommend a learning rate between `5e-7` and `5e-6`. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, use more epochs.
+
+
+
+Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
```py
diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py
index 4cac67e915..3acca53bde 100644
--- a/trl/trainer/kto_config.py
+++ b/trl/trainer/kto_config.py
@@ -75,9 +75,11 @@ class KTOConfig(TrainingArguments):
from a string.
dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
+ disable_dropout (`bool`, *optional*, defaults to `True`):
+ Whether to disable dropout in the model.
"""
- learning_rate: float = 5e-7
+ learning_rate: float = 1e-6
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
@@ -90,6 +92,7 @@ class KTOConfig(TrainingArguments):
truncation_mode: str = "keep_end"
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
+ disable_dropout: bool = True
precompute_ref_log_probs: bool = False
model_init_kwargs: Optional[Dict[str, Any]] = None
ref_model_init_kwargs: Optional[Dict[str, Any]] = None
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index b670db9c6e..0652a25239 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -73,7 +73,11 @@
def _get_kl_dataset(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
- """Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions."""
+ """
+ Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
+ For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the same set as the matched
+ outputs y used to estimate the rewards in that batch, just paired with different x.
+ """
batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1]
batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1]
return batch
@@ -514,10 +518,10 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False
- # disable dropout in the model and reference model
- disable_dropout_in_model(model)
- if self.ref_model is not None:
- disable_dropout_in_model(self.ref_model)
+ if args.disable_dropout:
+ disable_dropout_in_model(model)
+ if self.ref_model is not None:
+ disable_dropout_in_model(self.ref_model)
self.loss_type = args.loss_type
self.max_length = max_length
@@ -601,14 +605,9 @@ def make_inputs_require_grad(module, input, output):
# Get KL datasets if needed
if self.calculate_KL:
- total_batch_size = (
- max(torch.cuda.device_count(), 1)
- * args.per_device_train_batch_size
- * args.gradient_accumulation_steps
- )
- if total_batch_size <= 1:
+ if args.per_device_train_batch_size <= 1:
raise ValueError(
- "Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
+ "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
)
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
@@ -616,7 +615,7 @@ def make_inputs_require_grad(module, input, output):
train_kl_dataset = train_dataset.map(
_get_kl_dataset,
batched=True,
- batch_size=total_batch_size,
+ batch_size=args.per_device_train_batch_size,
num_proc=args.dataset_num_proc,
desc="Extracting KL train dataset",
)
@@ -638,7 +637,7 @@ def make_inputs_require_grad(module, input, output):
eval_kl_dataset = eval_dataset.map(
_get_kl_dataset,
batched=True,
- batch_size=total_batch_size,
+ batch_size=args.per_device_train_batch_size,
num_proc=args.dataset_num_proc,
desc="Extracting eval KL dataset",
)