diff --git a/tests/fixtures/configs/train/lddpo/base.json b/tests/fixtures/configs/train/lddpo/base.json
new file mode 100644
index 0000000..43eae2a
--- /dev/null
+++ b/tests/fixtures/configs/train/lddpo/base.json
@@ -0,0 +1,146 @@
+{
+    "train_dataset_settings": {
+        "sources": [
+            {
+                "name": "rm_preferences_test",
+                "records_path": "tests/fixtures/datasets/rm/train_preferences.jsonl",
+                "sample_rate": 1
+            }
+        ],
+        "chat_settings":{
+            "prompt_template": {
+                "role_tag_mapping": {
+                    "bot": "<bot>",
+                    "user": "<user>",
+                    "system": "<system>"
+                },
+                "prefix_template": "<RS>{role}",
+                "suffix_template": "</RS>"
+            },
+            "max_tokens_count": 120
+        },
+        "add_labels": true,
+        "dataset_type": "pair_preferences"
+    },
+    "val_dataset_settings": {
+        "sources": [
+            {
+                "name": "rm_preferences_test",
+                "records_path": "tests/fixtures/datasets/rm/val_preferences.jsonl",
+                "sample_rate": 1
+            }
+        ],
+        "chat_settings":{
+            "prompt_template": {
+                "role_tag_mapping": {
+                    "bot": "<bot>",
+                    "user": "<user>",
+                    "system": "<system>"
+                },
+                "prefix_template": "<RS>{role}",
+                "suffix_template": "</RS>"
+            },
+            "max_tokens_count": 120
+        },
+        "add_labels": true,
+        "dataset_type": "pair_preferences"
+    },
+    "model_settings": {
+        "model_path": "tests/fixtures/models/llama2_tiny",
+        "model_type": "causal",
+        "transformers_settings": {},
+        "adapter_path": "tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/trainer",
+        "is_trainable": true
+    },
+    "cherry_pick_settings": {
+        "generator_transformers_settings": {
+            "num_beams": 1,
+            "do_sample": false,
+            "stop_strings": "</RS>",
+            "max_new_tokens": 8
+        },
+        "custom_generation_settings": {
+            "skip_special_tokens": false
+        },
+        "dataset_settings": {
+            "sources": [
+                {
+                    "name": "chat_test",
+                    "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+                    "num_samples": 2
+                }
+            ],
+            "prompt_template": {
+                "role_tag_mapping": {
+                    "bot": "<bot>",
+                    "user": "<user>",
+                    "system": "<system>"
+                },
+                "prefix_template": "<RS>{role}",
+                "suffix_template": "</RS>"
+            },
+            "dataset_type": "chat",
+            "max_tokens_count": 150,
+            "only_answer_loss": true
+        },
+        "metric_settings": [
+            {
+                "type": "length",
+                "parameters": {"need_average": [true]}
+            },
+            {
+                "type": "kl",
+                "parameters": {
+                    "need_average": [true],
+                    "ref_logits_type": "sft"
+                }
+            },
+            {
+                "type": "kl",
+                "parameters": {
+                    "need_average": [true],
+                    "ref_logits_type": "reference"
+                }
+            }
+
+        ]
+    },
+    "tokenizer_settings": {},
+    "special_tokens_settings": {
+        "bos_token": "<s>",
+        "eos_token": "</s>",
+        "pad_token": "<pad>"
+    },
+    "trainer_settings": {
+        "lc_alpha": 0.5,
+        "evaluation_strategy": "steps",
+        "per_device_train_batch_size": 2,
+        "per_device_eval_batch_size": 2,
+        "gradient_accumulation_steps": 2,
+        "eval_steps": 4,
+        "save_steps": 4,
+        "logging_steps": 1,
+        "learning_rate": 0.0003,
+        "num_train_epochs": 2,
+        "lr_scheduler_type": "cosine",
+        "warmup_steps": 2,
+        "fp16": false,
+        "bf16": false,
+        "optim": "adamw_torch",
+        "save_total_limit": 1,
+        "loss_settings": {
+            "loss_type": "ipo"
+        },
+        "sync_ref_settings": {
+            "sync_ref_model": true
+        },
+        "use_sft_model": true,
+        "no_cuda": true
+    },
+    "logging_settings": {
+        "project_name": "alignment",
+        "run_name": "dpo",
+        "entity": "turbo-alignment"
+    },
+    "log_path": "test_dpo_llama_train_output"
+}
diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py
index da109ff..8bd7ec6 100755
--- a/turbo_alignment/cli/train.py
+++ b/turbo_alignment/cli/train.py
@@ -59,6 +59,19 @@ def train_ddpo_entrypoint(
     pipelines.TrainDDPOStrategy().run(experiment_settings)
 
 
+@app.command(name='train_lddpo', help='Run LDDPO pipeline')
+def train_lddpo_entrypoint(
+    experiment_settings_path: Path = typer.Option(
+        ...,
+        '--experiment_settings_path',
+        exists=True,
+        help='Path to experiment config file',
+    )
+) -> None:
+    experiment_settings = pipeline_settings.LDDPOTrainExperimentSettings.parse_file(experiment_settings_path)
+    pipelines.TrainLDDPOStrategy().run(experiment_settings)
+
+
 @app.command(name='train_rm', help='Run RM pipeline')
 def train_rm_entrypoint(
     experiment_settings_path: Path = typer.Option(
diff --git a/turbo_alignment/pipelines/train/__init__.py b/turbo_alignment/pipelines/train/__init__.py
index ac24318..8836280 100755
--- a/turbo_alignment/pipelines/train/__init__.py
+++ b/turbo_alignment/pipelines/train/__init__.py
@@ -2,6 +2,7 @@
 from .ddpo import TrainDDPOStrategy
 from .dpo import TrainDPOStrategy
 from .kto import TrainKTOStrategy
+from .lddpo import TrainLDDPOStrategy
 from .multimodal import TrainMultimodalStrategy
 from .rag import TrainRAGStrategy
 from .rm import TrainRMStrategy
diff --git a/turbo_alignment/pipelines/train/lddpo.py b/turbo_alignment/pipelines/train/lddpo.py
new file mode 100644
index 0000000..7eb3d51
--- /dev/null
+++ b/turbo_alignment/pipelines/train/lddpo.py
@@ -0,0 +1,123 @@
+from typing import Callable
+
+from torch.utils.data import Dataset
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers.data.data_collator import DataCollatorMixin
+
+from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback
+from turbo_alignment.common.logging import get_project_logger
+from turbo_alignment.common.tf.loaders.model import load_model
+from turbo_alignment.constants import TRAINER_LOGS_FOLDER
+from turbo_alignment.dataset.chat.chat import InferenceChatDataset
+from turbo_alignment.dataset.loader import DatasetLoader
+from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator
+from turbo_alignment.metrics.metric import Metric
+from turbo_alignment.metrics.registry import MetricSettingsRegistry
+from turbo_alignment.pipelines.train.base import BaseTrainStrategy
+from turbo_alignment.settings.datasets.base import DatasetStrategy
+from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
+from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings
+from turbo_alignment.trainers.lddpo import LDDPOTrainer, LDDPOTrainingArguments
+
+logger = get_project_logger()
+
+
+class TrainLDDPOStrategy(BaseTrainStrategy[LDDPOTrainExperimentSettings]):
+    @staticmethod
+    def _get_data_collator(
+        experiment_settings: LDDPOTrainExperimentSettings,
+        tokenizer: PreTrainedTokenizerBase,
+        **kwargs,
+    ) -> Callable:
+        return PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=True)
+
+    @staticmethod
+    def _get_cherry_pick_callback(
+        experiment_settings: LDDPOTrainExperimentSettings,
+        tokenizer: PreTrainedTokenizerBase,
+        **kwargs,
+    ) -> ChatCherryPickCallback | None:
+        return None
+        # cherry_pick_settings = experiment_settings.cherry_pick_settings
+
+        # cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
+        #     cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
+        # )
+
+        # metrics = [
+        #     Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
+        #     for metric in cherry_pick_settings.metric_settings
+        # ]
+
+        # return ChatCherryPickCallback(
+        #     cherry_pick_settings=cherry_pick_settings,
+        #     datasets=cherry_pick_datasets,
+        #     metrics=metrics,
+        # )
+
+    def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
+        logger.info(f'Train sample example:\n{dataset[0]}')
+
+        logger.info(
+            'Input-w check: {input_ids}'.format(
+                input_ids=collator([dataset[0], dataset[1]])['inputs_w']['input_ids'][0]
+            )
+        )
+        logger.info(
+            'Mask-w check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_w']['attention_mask'][0])
+        )
+        logger.info(
+            'Input-l check: {input_ids}'.format(
+                input_ids=collator([dataset[0], dataset[1]])['inputs_l']['input_ids'][0]
+            )
+        )
+        logger.info(
+            'Mask-l check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_l']['attention_mask'][0])
+        )
+
+    @staticmethod
+    def _get_training_args(experiment_settings: LDDPOTrainExperimentSettings) -> LDDPOTrainingArguments:
+        return LDDPOTrainingArguments(
+            output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
+            label_names=[],
+            remove_unused_columns=False,
+            **experiment_settings.trainer_settings.dict(),
+        )
+
+    @staticmethod
+    def _get_trainer(
+        training_args: LDDPOTrainingArguments,
+        experiment_settings: DPOTrainExperimentSettings,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerBase,
+        train_dataset: Dataset,
+        val_dataset: Dataset,
+        data_collator: Callable,
+    ):
+        model.config.use_cache = not training_args.gradient_checkpointing
+
+        extra_args = {}
+        if experiment_settings.trainer_settings.use_ref_model:
+            ref_model = load_model(experiment_settings.model_settings, tokenizer)
+            for _, param in ref_model.named_parameters():
+                param.requires_grad = False
+
+            extra_args['ref_model'] = ref_model
+
+        if experiment_settings.trainer_settings.use_sft_model:
+            sft_model = load_model(experiment_settings.model_settings, tokenizer)
+            for _, param in sft_model.named_parameters():
+                param.requires_grad = False
+
+            extra_args['sft_model'] = sft_model
+
+        return LDDPOTrainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=val_dataset,
+            callbacks=[],
+            data_collator=data_collator,
+            tokenizer=tokenizer,
+            **extra_args,
+        )
diff --git a/turbo_alignment/settings/pipelines/__init__.py b/turbo_alignment/settings/pipelines/__init__.py
index 1b6143a..f3dece9 100755
--- a/turbo_alignment/settings/pipelines/__init__.py
+++ b/turbo_alignment/settings/pipelines/__init__.py
@@ -10,6 +10,7 @@
 from turbo_alignment.settings.pipelines.train.ddpo import DDPOTrainExperimentSettings
 from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
 from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings
+from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings
 from turbo_alignment.settings.pipelines.train.multimodal import (
     MultimodalTrainExperimentSettings,
 )
diff --git a/turbo_alignment/settings/pipelines/train/lddpo.py b/turbo_alignment/settings/pipelines/train/lddpo.py
new file mode 100644
index 0000000..c733d58
--- /dev/null
+++ b/turbo_alignment/settings/pipelines/train/lddpo.py
@@ -0,0 +1,12 @@
+from turbo_alignment.settings.pipelines.train.dpo import (
+    DPOTrainerSettings,
+    DPOTrainExperimentSettings,
+)
+
+
+class LDDPOTrainerSettings(DPOTrainerSettings):
+    lc_alpha: float
+
+
+class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings):
+    trainer_settings: LDDPOTrainerSettings
diff --git a/turbo_alignment/trainers/__init__.py b/turbo_alignment/trainers/__init__.py
index 0c170c5..b87be8d 100755
--- a/turbo_alignment/trainers/__init__.py
+++ b/turbo_alignment/trainers/__init__.py
@@ -3,6 +3,7 @@
 from .ddpo import DDPOTrainer
 from .dpo import DPOTrainer
 from .kto import KTOTrainer
+from .lddpo import LDDPOTrainer
 from .multigpu import MultiGPUCherryPicksTrainer
 from .multimodal import MultimodalTrainer
 from .rm import RMTrainer
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index ef3288f..94dd7b6 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -606,12 +606,18 @@ def _get_batch_logps(
         return (per_token_logps * loss_mask).sum(-1)
 
     def concatenated_forward(
-        self, model: nn.Module, batch: dict[str, Any]
+        self,
+        model: nn.Module,
+        batch: dict[str, Any],
+        get_from_dataset: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
         concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
 
         precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)
 
+        if get_from_dataset:
+            return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')
+
         all_logits = model(
             concatenated_batch['input_ids'],
             attention_mask=concatenated_batch['attention_mask'],
@@ -632,16 +638,19 @@
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
 
     def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
-        with torch.no_grad():
-            if model is not None:
-                (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
-            else:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    (
-                        chosen_logps,
-                        rejected_logps,
-                        *_,
-                    ) = self.concatenated_forward(self.model, batch)
+        if self.ref_model is None:
+            chosen_logps, rejected_logps = self.concatenated_forward(model, batch, get_from_dataset=True)
+        else:
+            with torch.no_grad():
+                if model is not None:
+                    (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
+                else:
+                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                        (
+                            chosen_logps,
+                            rejected_logps,
+                            *_,
+                        ) = self.concatenated_forward(self.model, batch)
 
         return chosen_logps, rejected_logps
 
@@ -780,9 +789,9 @@
         metrics[f'{prefix_name}grad_term'] = (
             (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item()
         )
-        metrics[f'{prefix_name}grad_term_std'] = (
-            (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
-        )
+        # metrics[f'{prefix_name}grad_term_std'] = (
+        #     (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
+        # )
 
         return metrics
 
@@ -822,6 +831,7 @@ def compute_loss(
         model: PreTrainedModel | nn.Module,
         inputs: dict[str, torch.Tensor | Any],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
         loss, metrics = self.get_batch_metrics(model, inputs, train_eval='train')
 
diff --git a/turbo_alignment/trainers/lddpo.py b/turbo_alignment/trainers/lddpo.py
new file mode 100644
index 0000000..ab97561
--- /dev/null
+++ b/turbo_alignment/trainers/lddpo.py
@@ -0,0 +1,92 @@
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback
+
+from turbo_alignment.constants import DISABLE_LOSS_LABEL
+from turbo_alignment.trainers.dpo import DPOTrainer, DPOTrainingArguments
+from turbo_alignment.trainers.utils import concatenated_inputs
+
+
+@dataclass
+class LDDPOTrainingArguments(DPOTrainingArguments):
+    lc_alpha: float = 0.1
+
+
+class LDDPOTrainer(DPOTrainer):
+    """
+    Length-desensitized DPO (LD-DPO) trainer, see https://arxiv.org/pdf/2409.06411
+
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel | nn.Module,
+        data_collator: Callable,
+        args: LDDPOTrainingArguments,
+        train_dataset: Dataset,
+        eval_dataset: Dataset,
+        tokenizer: PreTrainedTokenizerBase | None = None,
+        callbacks: list[TrainerCallback] | None = None,
+        **kwargs,
+    ):
+        self.lc_alpha = args.lc_alpha
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            tokenizer=tokenizer,
+            callbacks=callbacks,
+            **kwargs,
+        )
+
+    def _get_batch_logps(self, logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if logits.shape[:-1] != labels.shape:
+            raise ValueError('Logits (batch and sequence length dim) and labels must have the same shape.')
+
+        labels = labels[:, 1:].clone()
+        logits = logits[:, :-1, :]
+        loss_mask = labels != DISABLE_LOSS_LABEL
+
+        labels[labels == DISABLE_LOSS_LABEL] = 0
+
+        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+        return (per_token_logps * loss_mask), loss_mask
+
+    def concatenated_forward(
+        self, model: nn.Module, batch: dict[str, Any]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
+        concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
+
+        precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)
+
+        all_logits = model(
+            concatenated_batch['input_ids'],
+            attention_mask=concatenated_batch['attention_mask'],
+        ).logits.to(torch.float32)
+
+        all_logps, loss_mask = self._get_batch_logps(all_logits, concatenated_batch['labels'])
+
+        batch_size = concatenated_batch['input_ids'].size(0) // 2
+        chosen_mask, rejected_mask = loss_mask.split(batch_size, dim=0)
+
+        # 'Public' positions carry loss in both the chosen and the rejected half,
+        # i.e. everything up to the length of the shorter answer (both halves share the same prompt).
+        public_ = chosen_mask * rejected_mask
+        public_mask = torch.cat([public_, public_])
+
+        public_logps = all_logps * public_mask
+
+        # Length-desensitized mixing: public tokens keep full weight, excess tokens are scaled by lc_alpha.
+        all_logps = self.lc_alpha * all_logps + (1 - self.lc_alpha) * public_logps
+
+        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
+        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+
+        return chosen_logps.sum(-1), rejected_logps.sum(-1), chosen_logits, rejected_logits, precomputed_margins
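
Note (outside the patch): the length-desensitized reweighting in LDDPOTrainer.concatenated_forward keeps full weight for tokens up to the common answer length and scales the remaining tokens of the longer answer by lc_alpha before the per-sequence log-probs enter the usual DPO/IPO loss. Below is a minimal sketch of that reweighting on toy tensors; the values, shapes and variable names are illustrative only and are not taken from the library.

import torch

lc_alpha = 0.5  # plays the same role as LDDPOTrainingArguments.lc_alpha

# Per-token log-probs for one chosen/rejected pair, already zeroed outside the answer spans.
chosen_logps = torch.tensor([-0.5, -0.7, -0.4, -0.9, -1.1])  # 5 answer tokens
rejected_logps = torch.tensor([-0.6, -0.8, -1.0, 0.0, 0.0])  # 3 answer tokens, padded
chosen_mask = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
rejected_mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])

# Positions covered by both answers (the 'public' part), stacked for the chosen and rejected halves.
public = chosen_mask * rejected_mask
all_logps = torch.stack([chosen_logps, rejected_logps])
public_mask = torch.stack([public, public])

# Common-length tokens keep weight 1, excess tokens are down-weighted to lc_alpha.
reweighted = lc_alpha * all_logps + (1 - lc_alpha) * all_logps * public_mask
print(reweighted.sum(-1))  # sequence-level log-probs that feed the preference loss

With lc_alpha = 1.0 this reduces to the standard DPO behaviour; smaller values shrink the contribution of the tokens that only the longer answer has, which is the length-bias reduction described in the cited LD-DPO paper.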