From 751fb1d84b7815f0929803da635ed68700166c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sun, 12 Jan 2025 15:23:19 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=8F=9B=EF=B8=8F=20Improve=20DPO=20configu?= =?UTF-8?q?ration=20documentation=20structure=20(#2561)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * better structure dpo config * fix tests * fix regex * add contributing guidelines --- CONTRIBUTING.md | 101 ++++++++- tests/test_dpo_trainer.py | 31 ++- trl/trainer/dpo_config.py | 407 +++++++++++++++++++------------------ trl/trainer/dpo_trainer.py | 17 +- 4 files changed, 335 insertions(+), 221 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 13983328ea..45d66be65e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,7 +38,7 @@ pip install -e .[dev] ## Fixing outstanding issues -If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#create-a-pull-request) and open a Pull Request! +If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request! ## Submitting a bug-related issue or feature request @@ -257,6 +257,105 @@ That's how `make test` is implemented (without the `pip install` line)! You can specify a smaller set of tests to test only the feature you're working on. +### Writing documentation + +High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project. + +To illustrate what good documentation looks like, here’s an example of a well-documented function: + +````python +def replicate_str(string: str, n: int, sep: str = " ") -> str: + r""" + Replicate a string `n` times with a separator. + + Args: + string (`str`): + String to replicate. + n (`int`): + Number of times to replicate the string. + sep (`str`, *optional*, defaults to `" "`): + Separator to use between each replication. + + Returns: + `str`: The replicated string. + + Examples: + ```python + >>> replicate_str("hello", 3) + "hello hello hello" + >>> replicate_str("hello", 3, sep=", ") + "hello, hello, hello" + ``` + """ + return sep.join([string] * n) +```` + +* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability. +* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate") +* **Type Annotations:** + * Always include type definitions, indicating if a parameter is optional and specifying the default value. + * Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value. + E.g., for arguments that can't be `None` and aren't required: + + ```python + foo (`int`, *optional*, defaults to `4`): + ``` + + For arguments that can be `None` and are required: + + ```python + foo (`Optional[int]`): + ``` + + for arguments that can be `None` and aren't required: + + ```python + foo (`Optional[int]`, *optional*, defaults to `None`): + ``` + +* **String Defaults:** + * Ensured that default string values are wrapped in double quotes: + + ```python + defaults to `"foo"` + ``` + +* **Dictionary Typing:** + * Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs. +* **Default Value Formatting:** + * Consistently surrounded default values with backticks for improved formatting: + + ```python + defaults to `4` + ``` + +* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability. + + ```python + def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]: + r""" + Calculates basic statistics for a given dataset. + + Args: + > Data inputs + + data (`list[float]`): + A list of numerical values to analyze. + + > Configuration parameters + + precision (`int`, *optional*, defaults to `2`): + Number of decimal places to round the results. + include_variance (`bool`, *optional*, defaults to `False`): + Whether to include the variance of the dataset in the results. + + Returns: + `dict[str, float]`: + A dictionary containing calculated statistics such as mean, median, and optionally variance. + """ + ... + ``` + ### Deprecation and Backward Compatibility Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs. diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index a786f6a41e..3413a0884e 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -457,9 +457,10 @@ def test_dpo_trainer_padding_token_is_none(self): with self.assertRaisesRegex( ValueError, - expected_regex=r"Can't find `pad_token_id` in the `processing_class`. " - r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) " - r"before instantiating the trainer.", + expected_regex=r"`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in " + r"the `processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set " + r"`tokenizer.pad_token` \(e.g., `tokenizer.pad_token = tokenizer.eos_token`\) before instantiating " + r"the trainer.", ): trainer = DPOTrainer( model=self.model, @@ -490,24 +491,16 @@ def test_dpo_trainer_w_dataset_num_proc(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") tokenizer = AutoTokenizer.from_pretrained(self.model_id) - tokenizer.pad_token = None - with self.assertRaisesRegex( - ValueError, - expected_regex=r"Can't find `pad_token_id` in the `processing_class`. " - r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) " - r"before instantiating the trainer.", - ): - trainer = DPOTrainer( - model=self.model, - ref_model=None, - args=training_args, - processing_class=tokenizer, - train_dataset=dummy_dataset["train"], - eval_dataset=dummy_dataset["test"], - ) + trainer = DPOTrainer( + model=self.model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) - trainer.train() + trainer.train() def test_tr_dpo_trainer(self): with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index b7dd6b049a..fbbaaf9bbe 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from enum import Enum from typing import Any, Optional @@ -40,16 +41,64 @@ class DPOConfig(TrainingArguments): command line. Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the + [`DPOTrainer`] is provided as a string. + ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the + [`DPOTrainer`] is provided as a string. + model_adapter_name (`str` or `None`, *optional*, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str` or `None`, *optional*, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + force_use_ref_model (`bool`, *optional*, defaults to `False`): + If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set + this flag to `True`. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + use_num_logits_to_keep (`bool`, *optional*, defaults to `False`): + If `True`, only a specified number of logits are computed in the forward pass. This can be useful for + saving memory and speeding up training by not computing the logits for all tokens, especially in + scenarios when working with very long prompts where labels are ignored (-100). + + > Parameters that control the data preprocessing + + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + padding_value (`int` or `None`, *optional*, defaults to `None`): + Padding value to use. If `None`, the padding value of the tokenizer is used. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Padding value to use for labels. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to usewhen the prompt is too long, either `keep_end` or `keep_start`. + max_prompt_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the prompt. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion. + max_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the full sequence (prompt + completion). + padding_free (`bool`, *optional*, defaults to `False`): + Whether forward passes are performed without padding by flattening all sequences in the batch + into a single continuous sequence. This approach requires associating a `position_ids` vector to track + positional information. Currently, this is only supported with the `flash_attention_2` mechanism, as it + can handle the flattened batch structure. + precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): + Whether to precompute the log probabilities from the reference model. Setting this to `True` allows + training without needing the reference model during training, which can help reduce GPU memory usage. If + set to `False` (default), the reference model will be used during training to compute log probabilities + on-the-fly. + precompute_ref_batch_size (`int` or `None`, *optional*, defaults to `None`): + Batch size to use when precomputing reference model log probabilities. This can be set higher than the + training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for + training and `per_device_eval_batch_size` for evaluation. + + > Parameters that control the training + learning_rate (`float`, *optional*, defaults to `1e-6`): Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the deviation from the reference model. Higher β means less deviation from the - reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in - the [paper](https://huggingface.co/papers/2310.12036). - label_smoothing (`float`, *optional*, defaults to `0.0`): - Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and - [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. loss_type (`str`, *optional*, defaults to `"sigmoid"`): Type of loss to use. Possible values are: @@ -66,194 +115,142 @@ class DPOConfig(TrainingArguments): - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - use_weighting (`bool`, *optional*, defaults to `False`): - Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int` or `None`, *optional*, defaults to `None`): - Padding value to use. If `None`, the padding value of the tokenizer is used. - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the - default data collator. - max_length (`int` or `None`, *optional*, defaults to `None`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. - max_prompt_length (`int` or `None`, *optional*, defaults to `None`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. - max_completion_length (`int` or `None`, *optional*, defaults to `None`): - Maximum length of the target. This argument is required if you want to use the default data collator and - your model is an encoder-decoder. - is_encoder_decoder(`Optional[int]`, *optional*, defaults to `None`): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model and reference model. - generate_during_eval (`bool`, *optional*, defaults to `False`): - If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during - evaluation. - precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): - Whether to precompute reference model log probabilities for training and evaluation datasets. This is - useful when training without the reference model to reduce the total GPU memory needed. - precompute_ref_batch_size (`int` or `None`, *optional*, defaults to `None`): - Batch size to use when precomputing reference model log probabilities. This can be set higher than the - training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for - training and `per_device_eval_batch_size` for evaluation. - dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): - Number of processes to use for processing the dataset. - model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model - from a string. - model_adapter_name (`str` or `None`, *optional*, defaults to `None`): - Name of the train target PEFT adapter, when using LoRA with multiple adapters. - ref_adapter_name (`str` or `None`, *optional*, defaults to `None`): - Name of the reference PEFT adapter, when using LoRA with multiple adapters. - reference_free (`bool`, *optional*, defaults to `False`): - If `True`, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal - probability to all responses. - force_use_ref_model (`bool`, *optional*, defaults to `False`): - In case one passes a PEFT model for the active model and you want to use a different model for the - ref_model, set this flag to `True`. + + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in + the [paper](https://huggingface.co/papers/2310.12036). f_divergence_type (`str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`): Type of f-divergence regularization function to compute divergence between policy and reference model. f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`): α coefficient in the α-divergence u^-α regularization function for DPO loss. + reference_free (`bool`, *optional*, defaults to `False`): + Whether to ignore the provided reference model and implicitly use a reference model that assigns equal + probability to all responses. + label_smoothing (`float`, *optional*, defaults to `0.0`): + Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and + [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. + use_weighting (`bool`, *optional*, defaults to `False`): + Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper. + rpo_alpha (`float`, *optional*, defaults to `None`): + α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the + weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the + DPO loss. The paper recommends `rpo_alpha=1.0`. + discopop_tau (`float`, *optional*, defaults to `0.05`): + τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls + the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. sync_ref_model (`bool`, *optional*, defaults to `False`): - When set to `True`, the reference model is synchronized with the active model every `ref_model_sync_steps` - steps, using the `ref_model_mixup_alpha` parameter. This synchronization originites from the + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originites from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper. ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`): α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix between the current policy and the previous reference policy during updates. The reference policy is - updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev` - To use this parameter, you must set `sync_ref_model=True`. + updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. ref_model_sync_steps (`int`, *optional*, defaults to `64`): τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how frequently the current policy is synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`. - rpo_alpha (`float`, *optional*, defaults to `None`): - α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the - weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the - DPO loss. The paper recommends `rpo_alpha=1.0`. - discopop_tau (`float`, *optional*, defaults to `0.05`): - τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls - the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. - use_num_logits_to_keep (`bool`, *optional*, defaults to `False`): - If `True`, only a specified number of logits are computed in the forward pass of CausalLM. This can be - useful for saving memory and speeding up training by not computing the logits for all tokens, especially in - scenarios when working with very long prompts where labels are ignored (-100). - [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM) - padding_free (`bool`, *optional*, defaults to `False`): - Whether forward passes are performed without padding by flattening all sequences in the batch - into a single continuous sequence. This approach requires associating a `position_ids` vector to track - positional information. Currently, this is only supported with the `flash_attention_2` mechanism, as it - can handle the flattened batch structure. + + > Parameters that control the logging + + generate_during_eval (`bool`, *optional*, defaults to `False`): + Whether to generate and log completions from both the model and the reference model to W&B or Comet during + evaluation. """ - learning_rate: float = field( - default=1e-6, + # Parameters that control the model and reference model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " - "`transformers.TrainingArguments`." + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `DPOTrainer` is provided as a string." }, ) - beta: float = field( - default=0.1, + ref_model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, metadata={ - "help": "Parameter controlling the deviation from the reference model. " - "Higher β means less deviation from the reference model." + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument " + "of the `DPOTrainer` is provided as a string." }, ) - label_smoothing: float = field( - default=0.0, - metadata={"help": "Label smoothing factor."}, + model_adapter_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, ) - loss_type: str = field( - default="sigmoid", + ref_adapter_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, + ) + force_use_ref_model: bool = field( + default=False, metadata={ - "help": "Type of loss to use.", - "choices": [ - "sigmoid", - "hinge", - "ipo", - "exo_pair", - "nca_pair", - "robust", - "bco_pair", - "sppo_hard", - "aot", - "aot_pair", - "discopop", - "apo_zero", - "apo_down", - ], + "help": "If you provide a PEFT model as the active model and wish to use a different model for the " + "`ref_model`, set this flag to `True`." }, ) - use_weighting: bool = field( - default=False, - metadata={"help": "Whether to weight the loss as done in the WPO paper."}, + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, ) - label_pad_token_id: int = field( - default=-100, + use_num_logits_to_keep: bool = field( + default=False, metadata={ - "help": "Label pad token id. This argument is required if you want to use the default data collator." + "help": "If `True`, only a specified number of logits are computed in the forward pass. This can be " + "useful for saving memory and speeding up training by not computing the logits for all tokens, especially " + "in scenarios when working with very long prompts where labels are ignored (-100)." }, ) + + # Parameters that control the data preprocessing + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) padding_value: Optional[int] = field( default=None, metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, ) + label_pad_token_id: int = field( + default=-100, + metadata={"help": "Padding value to use for labels."}, + ) truncation_mode: str = field( default="keep_end", metadata={ - "help": "Truncation mode to use when the prompt is too long. This argument is required if you want to use " - "the default data collator.", + "help": "Truncation mode to use when the prompt is too long.", "choices": ["keep_end", "keep_start"], }, ) - max_length: Optional[int] = field( - default=None, - metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, - ) max_prompt_length: Optional[int] = field( default=None, - metadata={ - "help": "Maximum length of the prompt. This argument is required if you want to use the default data " - "collator and your model is an encoder-decoder." - }, + metadata={"help": "Maximum length of the prompt."}, ) max_completion_length: Optional[int] = field( default=None, - metadata={ - "help": "Maximum length of the completion. This argument is required if you want to use the default data " - "collator and your model is an encoder-decoder." - }, + metadata={"help": "Maximum length of the completion."}, ) - is_encoder_decoder: Optional[bool] = field( + max_length: Optional[int] = field( default=None, - metadata={ - "help": "When using the `model_init` argument (callable) to instantiate the model instead of the " - "`model` argument, you need to specify if the model returned by the callable is an encoder-decoder model." - }, + metadata={"help": "Maximum length of the full sequence (prompt + completion)."}, ) - disable_dropout: bool = field( - default=True, - metadata={"help": "Whether to disable dropout in the model and reference model."}, - ) - generate_during_eval: bool = field( + padding_free: bool = field( default=False, metadata={ - "help": "If `True`, generates and logs completions from both the model and the reference model " - "to W&B during evaluation." + "help": "Whether forward passes are performed without padding by flattening all sequences in the batch " + "into a single continuous sequence. This approach requires associating a `position_ids` vector to track " + "positional information. Currently, this is only supported with the `flash_attention_2` mechanism, as it " + "can handle the flattened batch structure." }, ) precompute_ref_log_probs: bool = field( default=False, metadata={ - "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. " - "This is useful when training without the reference model to reduce the total GPU memory needed." + "help": "Whether to precompute the log probabilities from the reference model. Setting this to `True` " + "allows training without needing the reference model during training, which can help reduce GPU memory " + "usage. If set to `False` (default), the reference model will be used during training to compute log " + "probabilities on-the-fly." }, ) precompute_ref_batch_size: Optional[int] = field( @@ -264,44 +261,41 @@ class DPOConfig(TrainingArguments): "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation." }, ) - dataset_num_proc: Optional[int] = field( - default=None, - metadata={"help": "Number of processes to use for processing the dataset."}, - ) - model_init_kwargs: Optional[dict[str, Any]] = field( - default=None, - metadata={ - "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " - "model from a string." - }, - ) - ref_model_init_kwargs: Optional[dict[str, Any]] = field( - default=None, + + # Parameters that control the training + learning_rate: float = field( + default=1e-6, metadata={ - "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " - "reference model from a string." + "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " + "`transformers.TrainingArguments`." }, ) - model_adapter_name: Optional[str] = field( - default=None, - metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, - ) - ref_adapter_name: Optional[str] = field( - default=None, - metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, - ) - reference_free: bool = field( - default=False, + loss_type: str = field( + default="sigmoid", metadata={ - "help": "If `True`, we ignore the _provided_ reference model and implicitly use a reference model that " - "assigns equal probability to all responses." + "help": "Type of loss to use.", + "choices": [ + "sigmoid", + "hinge", + "ipo", + "exo_pair", + "nca_pair", + "robust", + "bco_pair", + "sppo_hard", + "aot", + "aot_pair", + "discopop", + "apo_zero", + "apo_down", + ], }, ) - force_use_ref_model: bool = field( - default=False, + beta: float = field( + default=0.1, metadata={ - "help": "In case one passes a PEFT model for the active model and you want to use a different model for " - "the ref_model, set this flag to `True`." + "help": "Parameter controlling the deviation from the reference model. " + "Higher β means less deviation from the reference model." }, ) f_divergence_type: FDivergenceType = field( @@ -315,27 +309,23 @@ class DPOConfig(TrainingArguments): default=1.0, metadata={"help": "α coefficient in the α-divergence u^-α regularization function for DPO loss."}, ) - sync_ref_model: bool = field( + reference_free: bool = field( default=False, metadata={ - "help": "When set to `True`, the reference model is synchronized with the active model every " - "`ref_model_sync_steps` steps, using the `ref_model_mixup_alpha` parameter." + "help": "Whether to ignore the provided reference model and implicitly use a reference model that assigns " + "equal probability to all responses." }, ) - ref_model_mixup_alpha: float = field( - default=0.9, + label_smoothing: float = field( + default=0.0, metadata={ - "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " - "previous reference policy during updates. The reference policy is updated according to the equation: " - "`π_ref = α * π_θ + (1 - α) * π_ref_prev`" + "help": "Robust DPO label smoothing parameter from the cDPO report and Robust DPO paper that should " + "be between `0.0` and `0.5`." }, ) - ref_model_sync_steps: int = field( - default=64, - metadata={ - "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " - "synchronized with the reference policy." - }, + use_weighting: bool = field( + default=False, + metadata={"help": "Whether to weight the loss as done in the WPO paper."}, ) rpo_alpha: Optional[float] = field( default=None, @@ -352,18 +342,49 @@ class DPOConfig(TrainingArguments): "loss. The paper recommends the default value `discopop_tau=0.05`." }, ) - use_num_logits_to_keep: bool = field( + sync_ref_model: bool = field( default=False, metadata={ - "help": "If `True`, only a specified number of logits are computed in the forward pass of CausalLM. " - "This can be useful for saving memory and speeding up training by not computing the logits for all " - "tokens, especially in scenarios when working with very long prompts where labels are ignored (-100)." + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." }, ) - padding_free: bool = field( + ref_model_mixup_alpha: float = field( + default=0.9, + metadata={ + "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=64, + metadata={ + "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + + # Parameters that control the logging + generate_during_eval: bool = field( default=False, metadata={ - "help": "Whether the forward passes are performed without padding, i.e. flattening all the samples in the " - "batch into a single sample, associated with a position_ids vector. Only possible with flash-attention." + "help": "Whether to generate and log completions from both the model and the reference model to W&B or " + "Comet during evaluation." }, ) + + # Deprecated parameters + is_encoder_decoder: Optional[bool] = field( + default=None, + metadata={"help": "Deprecated. This argument is not used anymore."}, + ) + + def __post_init__(self): + if self.is_encoder_decoder is not None: + warnings.warn( + "The `is_encoder_decoder` parameter is deprecated will be removed in version 0.15. The trainer now " + "automatically determines if the model is an encoder-decoder, so you can safely remove it." + ) + + return super().__post_init__() diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3774c84b21..b9adbedf5f 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -372,9 +372,10 @@ def make_inputs_require_grad(module, input, output): self.padding_value = processing_class.tokenizer.pad_token_id else: raise ValueError( - "Can't find `pad_token_id` in the `processing_class`. " - "Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`) " - "before instantiating the trainer." + "`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in the " + "`processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set " + "`tokenizer.pad_token` (e.g., `tokenizer.pad_token = tokenizer.eos_token`) before instantiating " + "the trainer." ) if data_collator is None: @@ -1366,7 +1367,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) # if ref_output in batch use that otherwise use the reference model @@ -1380,7 +1381,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) else: ref_output = self.ref_model.generate( @@ -1388,13 +1389,13 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - ref_output = pad_to_length(ref_output, self.max_length, self.processing_class.pad_token_id) + ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) return policy_output_decoded, ref_output_decoded