diff --git a/tests/fixtures/configs/train/lddpo/base.json b/tests/fixtures/configs/train/lddpo/base.json
new file mode 100644
index 0000000..3140625
--- /dev/null
+++ b/tests/fixtures/configs/train/lddpo/base.json
@@ -0,0 +1,145 @@
+{
+    "train_dataset_settings": {
+        "sources": [
+            {
+                "name": "rm_preferences_test",
+                "records_path": "tests/fixtures/datasets/rm/train_preferences.jsonl",
+                "sample_rate": 1
+            }
+        ],
+        "chat_settings":{
+            "prompt_template": {
+                "role_tag_mapping": {
+                    "bot": "",
+                    "user": "",
+                    "system": ""
+                },
+                "prefix_template": "{role}",
+                "suffix_template": ""
+            },
+            "max_tokens_count": 120
+        },
+        "add_labels": true,
+        "dataset_type": "pair_preferences"
+    },
+    "val_dataset_settings": {
+        "sources": [
+            {
+                "name": "rm_preferences_test",
+                "records_path": "tests/fixtures/datasets/rm/val_preferences.jsonl",
+                "sample_rate": 1
+            }
+        ],
+        "chat_settings":{
+            "prompt_template": {
+                "role_tag_mapping": {
+                    "bot": "",
+                    "user": "",
+                    "system": ""
+                },
+                "prefix_template": "{role}",
+                "suffix_template": ""
+            },
+            "max_tokens_count": 120
+        },
+        "add_labels": true,
+        "dataset_type": "pair_preferences"
+    },
+    "model_settings": {
+        "model_path": "tests/fixtures/models/llama2_tiny",
+        "model_type": "causal",
+        "transformers_settings": {},
+        "adapter_path": "tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/trainer",
+        "is_trainable": true
+    },
+    "cherry_pick_settings": {
+        "generator_transformers_settings": {
+            "num_beams": 1,
+            "do_sample": false,
+            "stop_strings": "",
+            "max_new_tokens": 8
+        },
+        "custom_generation_settings": {
+            "skip_special_tokens": false
+        },
+        "dataset_settings": {
+            "sources": [
+                {
+                    "name": "chat_test",
+                    "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+                    "num_samples": 2
+                }
+            ],
+            "prompt_template": {
+                "role_tag_mapping": {
+                    "bot": "",
+                    "user": "",
+                    "system": ""
+                },
+                "prefix_template": "{role}",
+                "suffix_template": ""
+            },
+            "dataset_type": "chat",
+            "max_tokens_count": 150,
+            "only_answer_loss": true
+        },
+        "metric_settings": [
+            {
+                "type": "length",
+                "parameters": {"need_average": [true]}
+            },
+            {
+                "type": "kl",
+                "parameters": {
+                    "need_average": [true],
+                    "ref_logits_type": "sft"
+                }
+            },
+            {
+                "type": "kl",
+                "parameters": {
+                    "need_average": [true],
+                    "ref_logits_type": "reference"
+                }
+            }
+
+        ]
+    },
+    "tokenizer_settings": {},
+    "special_tokens_settings": {
+        "bos_token": "",
+        "eos_token": "",
+        "pad_token": ""
+    },
+    "trainer_settings": {
+        "evaluation_strategy": "steps",
+        "per_device_train_batch_size": 2,
+        "per_device_eval_batch_size": 2,
+        "gradient_accumulation_steps": 2,
+        "eval_steps": 4,
+        "save_steps": 4,
+        "logging_steps": 1,
+        "learning_rate": 0.0003,
+        "num_train_epochs": 2,
+        "lr_scheduler_type": "cosine",
+        "warmup_steps": 2,
+        "fp16": false,
+        "bf16": false,
+        "optim": "adamw_torch",
+        "save_total_limit": 1,
+        "loss_settings": {
+            "loss_type": "ipo"
+        },
+        "sync_ref_settings": {
+            "sync_ref_model": true
+        },
+        "use_sft_model": true,
+        "no_cuda": true
+    },
+    "logging_settings": {
+        "project_name": "alignment",
+        "run_name": "dpo",
+        "entity": "turbo-alignment"
+    },
+    "log_path": "test_dpo_llama_train_output"
+}
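
The fixture above mirrors the existing DPO test configs (pair-preference datasets, IPO loss, a synced reference model) and is consumed through the new LDDPOTrainExperimentSettings. A quick way to sanity-check it locally is to run the same parse that the CLI entrypoint in the next file performs. This is a hypothetical snippet, not part of the patch; it assumes the command is run from the repository root so the fixture paths resolve.

from turbo_alignment.settings.pipelines import LDDPOTrainExperimentSettings

settings = LDDPOTrainExperimentSettings.parse_file('tests/fixtures/configs/train/lddpo/base.json')
print(settings.trainer_settings.alpha)  # "alpha" is not set in the fixture, so the class default applies
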
diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py
index da109ff..8bd7ec6 100755
--- a/turbo_alignment/cli/train.py
+++ b/turbo_alignment/cli/train.py
@@ -59,6 +59,19 @@ def train_ddpo_entrypoint(
     pipelines.TrainDDPOStrategy().run(experiment_settings)
 
 
+@app.command(name='train_lddpo', help='Run LDDPO pipeline')
+def train_lddpo_entrypoint(
+    experiment_settings_path: Path = typer.Option(
+        ...,
+        '--experiment_settings_path',
+        exists=True,
+        help='Path to experiment config file',
+    )
+) -> None:
+    experiment_settings = pipeline_settings.LDDPOTrainExperimentSettings.parse_file(experiment_settings_path)
+    pipelines.TrainLDDPOStrategy().run(experiment_settings)
+
+
 @app.command(name='train_rm', help='Run RM pipeline')
 def train_rm_entrypoint(
     experiment_settings_path: Path = typer.Option(
diff --git a/turbo_alignment/pipelines/train/__init__.py b/turbo_alignment/pipelines/train/__init__.py
index ac24318..8836280 100755
--- a/turbo_alignment/pipelines/train/__init__.py
+++ b/turbo_alignment/pipelines/train/__init__.py
@@ -2,6 +2,7 @@
 from .ddpo import TrainDDPOStrategy
 from .dpo import TrainDPOStrategy
 from .kto import TrainKTOStrategy
+from .lddpo import TrainLDDPOStrategy
 from .multimodal import TrainMultimodalStrategy
 from .rag import TrainRAGStrategy
 from .rm import TrainRMStrategy
diff --git a/turbo_alignment/pipelines/train/lddpo.py b/turbo_alignment/pipelines/train/lddpo.py
new file mode 100644
index 0000000..611477f
--- /dev/null
+++ b/turbo_alignment/pipelines/train/lddpo.py
@@ -0,0 +1,122 @@
+from typing import Callable
+
+from torch.utils.data import Dataset
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers.data.data_collator import DataCollatorMixin
+
+from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback
+from turbo_alignment.common.logging import get_project_logger
+from turbo_alignment.common.tf.loaders.model import load_model
+from turbo_alignment.constants import TRAINER_LOGS_FOLDER
+from turbo_alignment.dataset.chat.chat import InferenceChatDataset
+from turbo_alignment.dataset.loader import DatasetLoader
+from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator
+from turbo_alignment.metrics.metric import Metric
+from turbo_alignment.metrics.registry import MetricSettingsRegistry
+from turbo_alignment.pipelines.train.base import BaseTrainStrategy
+from turbo_alignment.settings.datasets.base import DatasetStrategy
+from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
+from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings
+from turbo_alignment.trainers.lddpo import LDDPOTrainer, LDDPOTrainingArguments
+
+logger = get_project_logger()
+
+
+class TrainLDDPOStrategy(BaseTrainStrategy[LDDPOTrainExperimentSettings]):
+    @staticmethod
+    def _get_data_collator(
+        experiment_settings: LDDPOTrainExperimentSettings,
+        tokenizer: PreTrainedTokenizerBase,
+        **kwargs,
+    ) -> Callable:
+        return PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=True)
+
+    @staticmethod
+    def _get_cherry_pick_callback(
+        experiment_settings: LDDPOTrainExperimentSettings,
+        tokenizer: PreTrainedTokenizerBase,
+        **kwargs,
+    ) -> ChatCherryPickCallback:
+        cherry_pick_settings = experiment_settings.cherry_pick_settings
+
+        cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
+            cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
+        )
+
+        metrics = [
+            Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
+            for metric in cherry_pick_settings.metric_settings
+        ]
+
+        return ChatCherryPickCallback(
+            cherry_pick_settings=cherry_pick_settings,
+            datasets=cherry_pick_datasets,
+            metrics=metrics,
+        )
+
+    def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
+        logger.info(f'Train sample example:\n{dataset[0]}')
+
+        logger.info(
+            'Input-w check: {input_ids}'.format(
+                input_ids=collator([dataset[0], dataset[1]])['inputs_w']['input_ids'][0]
+            )
+        )
+        logger.info(
+            'Mask-w check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_w']['attention_mask'][0])
+        )
+        logger.info(
+            'Input-l check: {input_ids}'.format(
+                input_ids=collator([dataset[0], dataset[1]])['inputs_l']['input_ids'][0]
+            )
+        )
+        logger.info(
+            'Mask-l check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_l']['attention_mask'][0])
+        )
+
+    @staticmethod
+    def _get_training_args(experiment_settings: LDDPOTrainExperimentSettings) -> LDDPOTrainingArguments:
+        return LDDPOTrainingArguments(
+            output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
+            label_names=[],
+            remove_unused_columns=False,
+            **experiment_settings.trainer_settings.dict(),
+        )
+
+    @staticmethod
+    def _get_trainer(
+        training_args: LDDPOTrainingArguments,
+        experiment_settings: DPOTrainExperimentSettings,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerBase,
+        train_dataset: Dataset,
+        val_dataset: Dataset,
+        data_collator: Callable,
+    ):
+        model.config.use_cache = not training_args.gradient_checkpointing
+
+        extra_args = {}
+        if experiment_settings.trainer_settings.use_ref_model:
+            ref_model = load_model(experiment_settings.model_settings, tokenizer)
+            for _, param in ref_model.named_parameters():
+                param.requires_grad = False
+
+            extra_args['ref_model'] = ref_model
+
+        if experiment_settings.trainer_settings.use_sft_model:
+            sft_model = load_model(experiment_settings.model_settings, tokenizer)
+            for _, param in sft_model.named_parameters():
+                param.requires_grad = False
+
+            extra_args['sft_model'] = sft_model
+
+        return LDDPOTrainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=val_dataset,
+            callbacks=[],
+            data_collator=data_collator,
+            tokenizer=tokenizer,
+            **extra_args,
+        )
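
Both the sanity check above and LDDPOTrainer.concatenated_forward further below index the collated batch as batch['inputs_w'] / batch['inputs_l'] with 'input_ids', 'attention_mask' and 'labels' inside, plus an optional precomputed 'margin'. A rough, illustrative layout follows; the key names are taken from the code, while the nesting of 'labels', the shapes, the values and the -100 ignore label are assumptions.

import torch

batch = {
    'inputs_w': {  # chosen ("winner") completions
        'input_ids': torch.tensor([[1, 5, 6, 7, 2]]),
        'attention_mask': torch.tensor([[1, 1, 1, 1, 1]]),
        'labels': torch.tensor([[-100, -100, 6, 7, 2]]),  # prompt tokens excluded from the loss
    },
    'inputs_l': {  # rejected ("loser") completions, same keys
        'input_ids': torch.tensor([[1, 5, 8, 2, 0]]),
        'attention_mask': torch.tensor([[1, 1, 1, 1, 0]]),
        'labels': torch.tensor([[-100, -100, 8, 2, -100]]),
    },
}
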
diff --git a/turbo_alignment/settings/pipelines/__init__.py b/turbo_alignment/settings/pipelines/__init__.py
index 1b6143a..f3dece9 100755
--- a/turbo_alignment/settings/pipelines/__init__.py
+++ b/turbo_alignment/settings/pipelines/__init__.py
@@ -10,6 +10,7 @@
 from turbo_alignment.settings.pipelines.train.ddpo import DDPOTrainExperimentSettings
 from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
 from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings
+from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings
 from turbo_alignment.settings.pipelines.train.multimodal import (
     MultimodalTrainExperimentSettings,
 )
diff --git a/turbo_alignment/settings/pipelines/train/lddpo.py b/turbo_alignment/settings/pipelines/train/lddpo.py
new file mode 100644
index 0000000..3dd1ccb
--- /dev/null
+++ b/turbo_alignment/settings/pipelines/train/lddpo.py
@@ -0,0 +1,12 @@
+from turbo_alignment.settings.pipelines.train.dpo import (
+    DPOTrainerSettings,
+    DPOTrainExperimentSettings,
+)
+
+
+class LDDPOTrainerSettings(DPOTrainerSettings):
+    alpha: float = 1.0
+
+
+class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings):
+    trainer_settings: LDDPOTrainerSettings
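
One wiring detail worth keeping in mind when reading the two new classes: LDDPOTrainerSettings defaults alpha to 1.0, while LDDPOTrainingArguments (added further below) defaults it to 0.1. Because TrainLDDPOStrategy._get_training_args splats trainer_settings.dict() into the training arguments, the settings-level default is the one that actually reaches the trainer, so a config that omits "alpha", like the fixture above, effectively trains with alpha = 1.0, that is, plain DPO weighting. A toy reproduction of that precedence (stand-in classes, not the real ones):

from dataclasses import dataclass

from pydantic import BaseModel


class ToyTrainerSettings(BaseModel):  # stand-in for LDDPOTrainerSettings
    alpha: float = 1.0


@dataclass
class ToyTrainingArguments:  # stand-in for LDDPOTrainingArguments
    alpha: float = 0.1


args = ToyTrainingArguments(**ToyTrainerSettings().dict())
print(args.alpha)  # 1.0: the settings-level default wins over the dataclass default

Setting "alpha" inside "trainer_settings" in the experiment JSON is presumably the intended way to enable length desensitization.
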
diff --git a/turbo_alignment/trainers/__init__.py b/turbo_alignment/trainers/__init__.py
index 0c170c5..b87be8d 100755
--- a/turbo_alignment/trainers/__init__.py
+++ b/turbo_alignment/trainers/__init__.py
@@ -3,6 +3,7 @@
 from .ddpo import DDPOTrainer
 from .dpo import DPOTrainer
 from .kto import KTOTrainer
+from .lddpo import LDDPOTrainer
 from .multigpu import MultiGPUCherryPicksTrainer
 from .multimodal import MultimodalTrainer
 from .rm import RMTrainer
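
For review context: in the paper the trainer below cites (LD-DPO, arXiv 2409.06411), only the tokens of an answer that extend past the length of the shorter answer in the pair (its "public length" l_p) are down-weighted. As far as I can tell from the paper, the sequence log-probability that replaces log p_theta(y|x) inside the DPO loss is, for alpha in (0, 1]:

\log \hat{p}_\theta(y \mid x) = \sum_{k \le \ell_p} \log p_\theta(y_k \mid x, y_{<k}) + \alpha \sum_{k > \ell_p} \log p_\theta(y_k \mid x, y_{<k})

In other words, the per-token log-probabilities of the excess tokens are multiplied by alpha, and alpha = 1 recovers standard DPO; this is the masked scaling that concatenated_forward below applies.
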
diff --git a/turbo_alignment/trainers/lddpo.py b/turbo_alignment/trainers/lddpo.py
new file mode 100644
index 0000000..406b097
--- /dev/null
+++ b/turbo_alignment/trainers/lddpo.py
@@ -0,0 +1,111 @@
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback
+
+from turbo_alignment.constants import DISABLE_LOSS_LABEL
+from turbo_alignment.trainers.dpo import DPOTrainer, DPOTrainingArguments
+from turbo_alignment.trainers.utils import concatenated_inputs
+
+
+@dataclass
+class LDDPOTrainingArguments(DPOTrainingArguments):
+    alpha: float = 0.1
+
+
+class LDDPOTrainer(DPOTrainer):
+    """
+    Length-desensitized DPO (LD-DPO) trainer, from https://arxiv.org/pdf/2409.06411
+
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel | nn.Module,
+        data_collator: Callable,
+        args: LDDPOTrainingArguments,
+        train_dataset: Dataset,
+        eval_dataset: Dataset,
+        tokenizer: PreTrainedTokenizerBase | None = None,
+        callbacks: list[TrainerCallback] | None = None,
+        **kwargs,
+    ):
+        self.alpha = args.alpha
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            tokenizer=tokenizer,
+            callbacks=callbacks,
+            **kwargs,
+        )
+
+    def _get_batch_logps(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        if logits.shape[:-1] != labels.shape:
+            raise ValueError('Logits (batch and sequence length dim) and labels must have the same shape.')
+
+        labels = labels[:, 1:].clone()
+        logits = logits[:, :-1, :]
+
+        labels[labels == DISABLE_LOSS_LABEL] = 0
+
+        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+        return per_token_logps
+
+    def concatenated_forward(
+        self, model: nn.Module, batch: dict[str, Any]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
+        concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
+
+        precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)
+
+        all_logits = model(
+            concatenated_batch['input_ids'],
+            attention_mask=concatenated_batch['attention_mask'],
+        ).logits.to(torch.float32)
+
+        loss_mask = concatenated_batch['labels'][:, 1:] != DISABLE_LOSS_LABEL
+
+        per_token_logps = self._get_batch_logps(all_logits, concatenated_batch['labels'])
+        chosen_idxs = batch['inputs_w']['input_ids'].shape[0]
+        rejected_idx = batch['inputs_l']['input_ids'].shape[0]
+
+        chosen_logits = all_logits[:chosen_idxs]
+        rejected_logits = all_logits[chosen_idxs:]
+
+        chosen_per_token_logps = per_token_logps[:chosen_idxs]
+        rejected_per_token_logps = per_token_logps[chosen_idxs : chosen_idxs + rejected_idx]
+
+        chosen_loss_mask = loss_mask[:chosen_idxs]
+        rejected_loss_mask = loss_mask[chosen_idxs : chosen_idxs + rejected_idx]
+
+        min_lengths = torch.min(chosen_loss_mask.sum(-1), rejected_loss_mask.sum(-1))
+
+        answer_start_idx = torch.argmax(
+            chosen_loss_mask.int(), -1
+        )  # index of the first answer token in each sequence
+        split_start_idx = answer_start_idx + min_lengths  # first position beyond the shorter answer of the pair
+
+        # Mask selecting the tokens that lie beyond the length of the shorter answer
+        alpha_mask = torch.arange(
+            chosen_loss_mask.size(1), device=chosen_loss_mask.device
+        ).unsqueeze(0) >= split_start_idx.unsqueeze(1)
+
+        # Length desensitization: scale the per-token logprobs of these excess tokens by alpha
+        chosen_per_token_logps[alpha_mask] = chosen_per_token_logps[alpha_mask] * self.alpha
+        rejected_per_token_logps[alpha_mask] = rejected_per_token_logps[alpha_mask] * self.alpha
+
+        if self.average_log_prob:
+            chosen_logps = (chosen_per_token_logps * chosen_loss_mask).sum(-1) / chosen_loss_mask.sum(-1)
+            rejected_logps = (rejected_per_token_logps * rejected_loss_mask).sum(-1) / rejected_loss_mask.sum(-1)
+        else:
+            chosen_logps = (chosen_per_token_logps * chosen_loss_mask).sum(-1)
+            rejected_logps = (rejected_per_token_logps * rejected_loss_mask).sum(-1)
+
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
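
To make the masking in concatenated_forward easier to follow, here is a standalone toy run of the same index arithmetic with hand-sized tensors (hypothetical values, not part of the patch):

import torch

alpha = 0.1
# 1 marks answer tokens, 0 marks prompt/padding positions, for one chosen/rejected pair
chosen_loss_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1]], dtype=torch.bool)
rejected_loss_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0]], dtype=torch.bool)

min_lengths = torch.min(chosen_loss_mask.sum(-1), rejected_loss_mask.sum(-1))  # tensor([3]): shorter answer length
answer_start_idx = torch.argmax(chosen_loss_mask.int(), -1)                    # tensor([2]): first answer token
split_start_idx = answer_start_idx + min_lengths                               # tensor([5]): first "excess" position

alpha_mask = torch.arange(chosen_loss_mask.size(1)).unsqueeze(0) >= split_start_idx.unsqueeze(1)
print(alpha_mask.int())  # [[0, 0, 0, 0, 0, 1, 1]]: only the verbose tail is selected

chosen_per_token_logps = torch.full((1, 7), -2.0)
chosen_per_token_logps[alpha_mask] = chosen_per_token_logps[alpha_mask] * alpha
print(chosen_per_token_logps)  # the last two positions shrink to -0.2, the rest stay at -2.0

The chosen side here is two tokens longer than the rejected side, so only those two trailing tokens are attenuated; with alpha = 1.0 the scaling is a no-op and the loss matches plain DPO.
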