minor KTO setting changes + KL batch size (#2153)
* add argument for dropout

* increase default lr

* change default lr in examples

* fix bug in calculation of KL batch size

* KL batch size should be args.per_device_train_batch_size

* Update kto_trainer.mdx with hparam recs

* typo

* allow dropout to be disabled

* update lr in sample script

* Update kto_config.py

* Update trl/trainer/kto_trainer.py

* Update docs/source/kto_trainer.mdx

---------

Co-authored-by: Kashif Rasul <[email protected]>
kawine and kashif authored Oct 6, 2024
1 parent 4799ba4 commit f05c3fa
Showing 3 changed files with 26 additions and 18 deletions.
12 changes: 9 additions & 3 deletions docs/source/kto_trainer.mdx
@@ -7,6 +7,7 @@ For a full example have a look at [`examples/scripts/kto.py`].

Depending on how good your base model is, you may or may not need to do SFT before KTO.
This is different from standard RLHF and DPO, which always require SFT.
You can also train with imbalanced data (more chosen than rejected examples, or vice-versa), but you will need to adjust hyperparameters accordingly (see below).

## Expected dataset format

@@ -51,7 +52,8 @@ kto_dataset_dict = {
```

where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
In theory, the dataset must contain at least one desirable and one undesirable completion; however, some people have had success running KTO on _only_ desirable or undesirable data (in the latter case, it is best to use a conservative learning rate).
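
For illustration, a minimal dictionary in this format might look like the following (the strings below are hypothetical, not taken from the repository):

```python
# Hypothetical example: the same prompt appears twice, once with a desirable
# completion (label True) and once with an undesirable one (label False).
kto_dataset_dict = {
    "prompt": [
        "What color is the sky?",
        "What color is the sky?",
    ],
    "completion": [
        "The sky is blue.",
        "The sky is made of cheese.",
    ],
    "label": [
        True,
        False,
    ],
}
```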


## Expected model format
@@ -61,13 +63,17 @@ The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that

For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.

The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
The `beta` refers to the hyperparameter that controls how quickly the loss saturates, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e., decoder-only or encoder-decoder).

The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
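
As a concrete illustration, suppose a (hypothetical) dataset has 1000 desirable and 250 undesirable examples; the sketch below picks an `undesirable_weight` that keeps the weighted ratio inside the recommended range:

```python
# Hypothetical counts
num_desirable = 1000
num_undesirable = 250

desirable_weight = 1.0
# Recommended: (desirable_weight * num_desirable) : (undesirable_weight * num_undesirable)
# should fall between 1:1 and 4:3, i.e. undesirable_weight * 250 in [750, 1000].
undesirable_weight = num_desirable / num_undesirable  # 4.0 -> exactly 1:1
# Any value between 3.0 (a 4:3 ratio) and 4.0 (a 1:1 ratio) also satisfies the recommendation.
```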

<Tip>
It is strongly recommended you use a learning rate between `5e-7` and `5e-6` with an effective batch size between `8` and `32`, for both LoRA and full finetuning. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, using smaller batch sizes and/or more training epochs will give you better results.
Every choice of `beta` has a maximum learning rate it will tolerate before learning degenerates. For the default `beta = 0.1`, this learning rate is `1e-6` for most models. The lower the beta is, the lower your learning rate should be. In general, we strongly recommend a learning rate between `5e-7` and `5e-6`. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, use more epochs.
</Tip>

<Tip>
Use a per-step batch size of at least 4 and an effective batch size between 16 and 128. Even if your effective batch size is large, a small per-step batch size will make the KL estimate in KTO poor.
</Tip>
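
Putting these recommendations together, a configuration along the following lines stays within the suggested ranges (the values are illustrative, not prescriptive):

```python
from trl import KTOConfig

training_args = KTOConfig(
    output_dir="kto-model",            # illustrative output path
    learning_rate=1e-6,                # the new default, suitable for beta = 0.1
    per_device_train_batch_size=8,     # per-step batch size of at least 4
    gradient_accumulation_steps=4,     # effective batch size of 8 * 4 = 32 on one device
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
)
```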

5 changes: 4 additions & 1 deletion trl/trainer/kto_config.py
@@ -75,9 +75,11 @@ class KTOConfig(TrainingArguments):
from a string.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
"""

learning_rate: float = 5e-7
learning_rate: float = 1e-6
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
@@ -90,6 +92,7 @@ class KTOConfig(TrainingArguments):
truncation_mode: str = "keep_end"
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
disable_dropout: bool = True
precompute_ref_log_probs: bool = False
model_init_kwargs: Optional[Dict[str, Any]] = None
ref_model_init_kwargs: Optional[Dict[str, Any]] = None
27 changes: 13 additions & 14 deletions trl/trainer/kto_trainer.py
@@ -73,7 +73,11 @@


def _get_kl_dataset(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
"""Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions."""
"""
Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the same set as the matched
outputs y used to estimate the rewards in that batch, just paired with different x.
"""
batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1]
batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1]
return batch
@@ -514,10 +518,10 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

# disable dropout in the model and reference model
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)

self.loss_type = args.loss_type
self.max_length = max_length
@@ -601,22 +605,17 @@ def make_inputs_require_grad(module, input, output):

# Get KL datasets if needed
if self.calculate_KL:
total_batch_size = (
max(torch.cuda.device_count(), 1)
* args.per_device_train_batch_size
* args.gradient_accumulation_steps
)
if total_batch_size <= 1:
if args.per_device_train_batch_size <= 1:
raise ValueError(
"Batch size is 1 (too small). KTO will not work properly because the KL term will be equivalent to the implied reward."
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
)

            # create pairs for estimating the KL term by shifting the matched completions within each batch of size args.per_device_train_batch_size
            # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), (x_2, y_1), ..., (x_n, y_{n-1}) = (x'_1, y'_1), ..., (x'_n, y'_n)
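            # the map batch_size must match args.per_device_train_batch_size so that the mismatched completions used for the KL estimate come from the same per-device batch as the matched completions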
train_kl_dataset = train_dataset.map(
_get_kl_dataset,
batched=True,
batch_size=total_batch_size,
batch_size=args.per_device_train_batch_size,
num_proc=args.dataset_num_proc,
desc="Extracting KL train dataset",
)
@@ -638,7 +637,7 @@ def make_inputs_require_grad(module, input, output):
eval_kl_dataset = eval_dataset.map(
_get_kl_dataset,
batched=True,
batch_size=total_batch_size,
batch_size=args.per_device_train_batch_size,
num_proc=args.dataset_num_proc,
desc="Extracting eval KL dataset",
)
