From 20b97d10d16b8c0fb25df3a41d722f1e159f0c43 Mon Sep 17 00:00:00 2001 From: lmeribal Date: Wed, 27 Nov 2024 07:04:36 +0000 Subject: [PATCH 1/6] ld-dpo trainer --- tests/fixtures/configs/train/lddpo/base.json | 145 ++++++++++++++++++ turbo_alignment/cli/train.py | 13 ++ turbo_alignment/pipelines/train/__init__.py | 1 + turbo_alignment/pipelines/train/lddpo.py | 122 +++++++++++++++ .../settings/pipelines/__init__.py | 1 + .../settings/pipelines/train/lddpo.py | 12 ++ turbo_alignment/trainers/__init__.py | 1 + turbo_alignment/trainers/lddpo.py | 111 ++++++++++++++ 8 files changed, 406 insertions(+) create mode 100644 tests/fixtures/configs/train/lddpo/base.json create mode 100644 turbo_alignment/pipelines/train/lddpo.py create mode 100644 turbo_alignment/settings/pipelines/train/lddpo.py create mode 100644 turbo_alignment/trainers/lddpo.py diff --git a/tests/fixtures/configs/train/lddpo/base.json b/tests/fixtures/configs/train/lddpo/base.json new file mode 100644 index 0000000..3140625 --- /dev/null +++ b/tests/fixtures/configs/train/lddpo/base.json @@ -0,0 +1,145 @@ +{ + "train_dataset_settings": { + "sources": [ + { + "name": "rm_preferences_test", + "records_path": "tests/fixtures/datasets/rm/train_preferences.jsonl", + "sample_rate": 1 + } + ], + "chat_settings":{ + "prompt_template": { + "role_tag_mapping": { + "bot": "", + "user": "", + "system": "" + }, + "prefix_template": "{role}", + "suffix_template": "" + }, + "max_tokens_count": 120 + }, + "add_labels": true, + "dataset_type": "pair_preferences" + }, + "val_dataset_settings": { + "sources": [ + { + "name": "rm_preferences_test", + "records_path": "tests/fixtures/datasets/rm/val_preferences.jsonl", + "sample_rate": 1 + } + ], + "chat_settings":{ + "prompt_template": { + "role_tag_mapping": { + "bot": "", + "user": "", + "system": "" + }, + "prefix_template": "{role}", + "suffix_template": "" + }, + "max_tokens_count": 120 + }, + "add_labels": true, + "dataset_type": "pair_preferences" + }, + "model_settings": { + "model_path": "tests/fixtures/models/llama2_tiny", + "model_type": "causal", + "transformers_settings": {}, + "adapter_path": "tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/trainer", + "is_trainable": true + }, + "cherry_pick_settings": { + "generator_transformers_settings": { + "num_beams": 1, + "do_sample": false, + "stop_strings": "", + "max_new_tokens": 8 + }, + "custom_generation_settings": { + "skip_special_tokens": false + }, + "dataset_settings": { + "sources": [ + { + "name": "chat_test", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "num_samples": 2 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "", + "user": "", + "system": "" + }, + "prefix_template": "{role}", + "suffix_template": "" + }, + "dataset_type": "chat", + "max_tokens_count": 150, + "only_answer_loss": true + }, + "metric_settings": [ + { + "type": "length", + "parameters": {"need_average": [true]} + }, + { + "type": "kl", + "parameters": { + "need_average": [true], + "ref_logits_type": "sft" + } + }, + { + "type": "kl", + "parameters": { + "need_average": [true], + "ref_logits_type": "reference" + } + } + + ] + }, + "tokenizer_settings": {}, + "special_tokens_settings": { + "bos_token": "", + "eos_token": "", + "pad_token": "" + }, + "trainer_settings": { + "evaluation_strategy": "steps", + "per_device_train_batch_size": 2, + "per_device_eval_batch_size": 2, + "gradient_accumulation_steps": 2, + "eval_steps": 4, + "save_steps": 4, + "logging_steps": 1, + "learning_rate": 0.0003, + 
"num_train_epochs": 2, + "lr_scheduler_type": "cosine", + "warmup_steps": 2, + "fp16": false, + "bf16": false, + "optim": "adamw_torch", + "save_total_limit": 1, + "loss_settings": { + "loss_type": "ipo" + }, + "sync_ref_settings": { + "sync_ref_model": true + }, + "use_sft_model": true, + "no_cuda": true + }, + "logging_settings": { + "project_name": "alignment", + "run_name": "dpo", + "entity": "turbo-alignment" + }, + "log_path": "test_dpo_llama_train_output" +} diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py index da109ff..8bd7ec6 100755 --- a/turbo_alignment/cli/train.py +++ b/turbo_alignment/cli/train.py @@ -59,6 +59,19 @@ def train_ddpo_entrypoint( pipelines.TrainDDPOStrategy().run(experiment_settings) +@app.command(name='train_lddpo', help='Run LDDPO pipeline') +def train_lddpo_entrypoint( + experiment_settings_path: Path = typer.Option( + ..., + '--experiment_settings_path', + exists=True, + help='Path to experiment config file', + ) +) -> None: + experiment_settings = pipeline_settings.LDDPOTrainExperimentSettings.parse_file(experiment_settings_path) + pipelines.TrainLDDPOStrategy().run(experiment_settings) + + @app.command(name='train_rm', help='Run RM pipeline') def train_rm_entrypoint( experiment_settings_path: Path = typer.Option( diff --git a/turbo_alignment/pipelines/train/__init__.py b/turbo_alignment/pipelines/train/__init__.py index ac24318..8836280 100755 --- a/turbo_alignment/pipelines/train/__init__.py +++ b/turbo_alignment/pipelines/train/__init__.py @@ -2,6 +2,7 @@ from .ddpo import TrainDDPOStrategy from .dpo import TrainDPOStrategy from .kto import TrainKTOStrategy +from .lddpo import TrainLDDPOStrategy from .multimodal import TrainMultimodalStrategy from .rag import TrainRAGStrategy from .rm import TrainRMStrategy diff --git a/turbo_alignment/pipelines/train/lddpo.py b/turbo_alignment/pipelines/train/lddpo.py new file mode 100644 index 0000000..611477f --- /dev/null +++ b/turbo_alignment/pipelines/train/lddpo.py @@ -0,0 +1,122 @@ +from typing import Callable + +from torch.utils.data import Dataset +from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers.data.data_collator import DataCollatorMixin + +from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback +from turbo_alignment.common.logging import get_project_logger +from turbo_alignment.common.tf.loaders.model import load_model +from turbo_alignment.constants import TRAINER_LOGS_FOLDER +from turbo_alignment.dataset.chat.chat import InferenceChatDataset +from turbo_alignment.dataset.loader import DatasetLoader +from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator +from turbo_alignment.metrics.metric import Metric +from turbo_alignment.metrics.registry import MetricSettingsRegistry +from turbo_alignment.pipelines.train.base import BaseTrainStrategy +from turbo_alignment.settings.datasets.base import DatasetStrategy +from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings +from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings +from turbo_alignment.trainers.lddpo import LDDPOTrainer, LDDPOTrainingArguments + +logger = get_project_logger() + + +class TrainLDDPOStrategy(BaseTrainStrategy[LDDPOTrainExperimentSettings]): + @staticmethod + def _get_data_collator( + experiment_settings: LDDPOTrainExperimentSettings, + tokenizer: PreTrainedTokenizerBase, + **kwargs, + ) -> Callable: + return PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=True) + + 
@staticmethod + def _get_cherry_pick_callback( + experiment_settings: LDDPOTrainExperimentSettings, + tokenizer: PreTrainedTokenizerBase, + **kwargs, + ) -> ChatCherryPickCallback: + cherry_pick_settings = experiment_settings.cherry_pick_settings + + cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets( + cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE + ) + + metrics = [ + Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters)) + for metric in cherry_pick_settings.metric_settings + ] + + return ChatCherryPickCallback( + cherry_pick_settings=cherry_pick_settings, + datasets=cherry_pick_datasets, + metrics=metrics, + ) + + def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None: + logger.info(f'Train sample example:\n{dataset[0]}') + + logger.info( + 'Input-w check: {input_ids}'.format( + input_ids=collator([dataset[0], dataset[1]])['inputs_w']['input_ids'][0] + ) + ) + logger.info( + 'Mask-w check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_w']['attention_mask'][0]) + ) + logger.info( + 'Input-l check: {input_ids}'.format( + input_ids=collator([dataset[0], dataset[1]])['inputs_l']['input_ids'][0] + ) + ) + logger.info( + 'Mask-l check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_l']['attention_mask'][0]) + ) + + @staticmethod + def _get_training_args(experiment_settings: LDDPOTrainExperimentSettings) -> LDDPOTrainingArguments: + return LDDPOTrainingArguments( + output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), + label_names=[], + remove_unused_columns=False, + **experiment_settings.trainer_settings.dict(), + ) + + @staticmethod + def _get_trainer( + training_args: LDDPOTrainingArguments, + experiment_settings: DPOTrainExperimentSettings, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + train_dataset: Dataset, + val_dataset: Dataset, + data_collator: Callable, + ): + model.config.use_cache = not training_args.gradient_checkpointing + + extra_args = {} + if experiment_settings.trainer_settings.use_ref_model: + ref_model = load_model(experiment_settings.model_settings, tokenizer) + for _, param in ref_model.named_parameters(): + param.requires_grad = False + + extra_args['ref_model'] = ref_model + + if experiment_settings.trainer_settings.use_sft_model: + sft_model = load_model(experiment_settings.model_settings, tokenizer) + for _, param in sft_model.named_parameters(): + param.requires_grad = False + + extra_args['sft_model'] = sft_model + + return LDDPOTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=val_dataset, + callbacks=[], + data_collator=data_collator, + tokenizer=tokenizer, + **extra_args, + ) diff --git a/turbo_alignment/settings/pipelines/__init__.py b/turbo_alignment/settings/pipelines/__init__.py index 1b6143a..f3dece9 100755 --- a/turbo_alignment/settings/pipelines/__init__.py +++ b/turbo_alignment/settings/pipelines/__init__.py @@ -10,6 +10,7 @@ from turbo_alignment.settings.pipelines.train.ddpo import DDPOTrainExperimentSettings from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings +from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings from turbo_alignment.settings.pipelines.train.multimodal import ( MultimodalTrainExperimentSettings, ) diff --git 
a/turbo_alignment/settings/pipelines/train/lddpo.py b/turbo_alignment/settings/pipelines/train/lddpo.py new file mode 100644 index 0000000..3dd1ccb --- /dev/null +++ b/turbo_alignment/settings/pipelines/train/lddpo.py @@ -0,0 +1,12 @@ +from turbo_alignment.settings.pipelines.train.dpo import ( + DPOTrainerSettings, + DPOTrainExperimentSettings, +) + + +class LDDPOTrainerSettings(DPOTrainerSettings): + alpha: float = 1.0 + + +class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings): + trainer_settings: LDDPOTrainerSettings diff --git a/turbo_alignment/trainers/__init__.py b/turbo_alignment/trainers/__init__.py index 0c170c5..b87be8d 100755 --- a/turbo_alignment/trainers/__init__.py +++ b/turbo_alignment/trainers/__init__.py @@ -3,6 +3,7 @@ from .ddpo import DDPOTrainer from .dpo import DPOTrainer from .kto import KTOTrainer +from .lddpo import LDDPOTrainer from .multigpu import MultiGPUCherryPicksTrainer from .multimodal import MultimodalTrainer from .rm import RMTrainer diff --git a/turbo_alignment/trainers/lddpo.py b/turbo_alignment/trainers/lddpo.py new file mode 100644 index 0000000..406b097 --- /dev/null +++ b/turbo_alignment/trainers/lddpo.py @@ -0,0 +1,111 @@ +from dataclasses import dataclass +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import Dataset +from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback + +from turbo_alignment.constants import DISABLE_LOSS_LABEL +from turbo_alignment.trainers.dpo import DPOTrainer, DPOTrainingArguments +from turbo_alignment.trainers.utils import concatenated_inputs + + +@dataclass +class LDDPOTrainingArguments(DPOTrainingArguments): + alpha: float = 0.1 + + +class LDDPOTrainer(DPOTrainer): + """ + From https://arxiv.org/pdf/2409.06411 + + """ + + def __init__( + self, + model: PreTrainedModel | nn.Module, + data_collator: Callable, + args: LDDPOTrainingArguments, + train_dataset: Dataset, + eval_dataset: Dataset, + tokenizer: PreTrainedTokenizerBase | None = None, + callbacks: list[TrainerCallback] | None = None, + **kwargs, + ): + self.alpha = args.alpha + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + callbacks=callbacks, + **kwargs, + ) + + def _get_batch_logps(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + if logits.shape[:-1] != labels.shape: + raise ValueError('Logits (batch and sequence length dim) and labels must have the same shape.') + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + + labels[labels == DISABLE_LOSS_LABEL] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + return per_token_logps + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, Any] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device) + + precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None) + + all_logits = model( + concatenated_batch['input_ids'], + attention_mask=concatenated_batch['attention_mask'], + ).logits.to(torch.float32) + + loss_mask = concatenated_batch['labels'][:, 1:] != DISABLE_LOSS_LABEL + + per_token_logps = self._get_batch_logps(all_logits, concatenated_batch['labels']) + chosen_idxs = batch['inputs_w']['input_ids'].shape[0] + rejected_idx = 
batch['inputs_l']['input_ids'].shape[0] + + chosen_logits = all_logits[:chosen_idxs] + rejected_logits = all_logits[chosen_idxs:] + + chosen_per_token_logps = per_token_logps[:chosen_idxs] + rejected_per_token_logps = per_token_logps[chosen_idxs : chosen_idxs + rejected_idx] + + chosen_loss_mask = loss_mask[:chosen_idxs] + rejected_loss_mask = loss_mask[chosen_idxs : chosen_idxs + rejected_idx] + + min_lengths = torch.min(chosen_loss_mask.sum(-1), rejected_loss_mask.sum(-1)) + + answer_start_idx = torch.argmax( + chosen_loss_mask.int(), -1 + ) # The index of the beginning of the chosen and rejected + split_start_idx = answer_start_idx + min_lengths # Add the length of the shorter answer + + # Setting the increment mask + alpha_mask = torch.arange(torch.full_like(chosen_loss_mask, 1).size(1)).unsqueeze( + 0 + ) >= split_start_idx.unsqueeze(1) + + # Incrementing by alpha logprobs that are out of bounds + chosen_per_token_logps[alpha_mask] = chosen_per_token_logps[alpha_mask] ** self.alpha + rejected_per_token_logps[alpha_mask] = rejected_per_token_logps[alpha_mask] ** self.alpha + + if self.average_log_prob: + chosen_logps = (chosen_per_token_logps * chosen_loss_mask).sum(-1) / chosen_loss_mask.sum(-1) + rejected_logps = (rejected_per_token_logps * rejected_loss_mask).sum(-1) / rejected_loss_mask.sum(-1) + chosen_logps = (chosen_per_token_logps * chosen_loss_mask).sum(-1) + rejected_logps = (rejected_per_token_logps * rejected_loss_mask).sum(-1) + + return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins From e083ef6d19cb0bad64d51c99816b78955d080cc6 Mon Sep 17 00:00:00 2001 From: lmeribal Date: Wed, 27 Nov 2024 08:00:46 +0000 Subject: [PATCH 2/6] ld-dpo v2 (from openreview code --- .../settings/pipelines/train/lddpo.py | 2 +- turbo_alignment/trainers/lddpo.py | 45 +++++-------------- 2 files changed, 12 insertions(+), 35 deletions(-) diff --git a/turbo_alignment/settings/pipelines/train/lddpo.py b/turbo_alignment/settings/pipelines/train/lddpo.py index 3dd1ccb..5c953f8 100644 --- a/turbo_alignment/settings/pipelines/train/lddpo.py +++ b/turbo_alignment/settings/pipelines/train/lddpo.py @@ -5,7 +5,7 @@ class LDDPOTrainerSettings(DPOTrainerSettings): - alpha: float = 1.0 + lc_alpha: float = 1.0 class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings): diff --git a/turbo_alignment/trainers/lddpo.py b/turbo_alignment/trainers/lddpo.py index 406b097..45ccc57 100644 --- a/turbo_alignment/trainers/lddpo.py +++ b/turbo_alignment/trainers/lddpo.py @@ -14,7 +14,7 @@ @dataclass class LDDPOTrainingArguments(DPOTrainingArguments): - alpha: float = 0.1 + lc_alpha: float = 0.1 class LDDPOTrainer(DPOTrainer): @@ -34,7 +34,7 @@ def __init__( callbacks: list[TrainerCallback] | None = None, **kwargs, ): - self.alpha = args.alpha + self.lc_alpha = args.lc_alpha super().__init__( model=model, @@ -72,40 +72,17 @@ def concatenated_forward( ).logits.to(torch.float32) loss_mask = concatenated_batch['labels'][:, 1:] != DISABLE_LOSS_LABEL + batch_size = concatenated_batch['input_ids'].size(0) // 2 + chosen_mask, rejected_mask = loss_mask.split(batch_size, dim=0) - per_token_logps = self._get_batch_logps(all_logits, concatenated_batch['labels']) - chosen_idxs = batch['inputs_w']['input_ids'].shape[0] - rejected_idx = batch['inputs_l']['input_ids'].shape[0] + all_logps = self._get_batch_logps(all_logits, concatenated_batch['labels']) - chosen_logits = all_logits[:chosen_idxs] - rejected_logits = all_logits[chosen_idxs:] + public_ = chosen_mask * rejected_mask + public_mask = 
torch.cat([public_, public_]) + public_logps = all_logps * public_mask + all_logps = self.lc_alpha * all_logps + (1 - self.lc_alpha) * public_logps - chosen_per_token_logps = per_token_logps[:chosen_idxs] - rejected_per_token_logps = per_token_logps[chosen_idxs : chosen_idxs + rejected_idx] - - chosen_loss_mask = loss_mask[:chosen_idxs] - rejected_loss_mask = loss_mask[chosen_idxs : chosen_idxs + rejected_idx] - - min_lengths = torch.min(chosen_loss_mask.sum(-1), rejected_loss_mask.sum(-1)) - - answer_start_idx = torch.argmax( - chosen_loss_mask.int(), -1 - ) # The index of the beginning of the chosen and rejected - split_start_idx = answer_start_idx + min_lengths # Add the length of the shorter answer - - # Setting the increment mask - alpha_mask = torch.arange(torch.full_like(chosen_loss_mask, 1).size(1)).unsqueeze( - 0 - ) >= split_start_idx.unsqueeze(1) - - # Incrementing by alpha logprobs that are out of bounds - chosen_per_token_logps[alpha_mask] = chosen_per_token_logps[alpha_mask] ** self.alpha - rejected_per_token_logps[alpha_mask] = rejected_per_token_logps[alpha_mask] ** self.alpha - - if self.average_log_prob: - chosen_logps = (chosen_per_token_logps * chosen_loss_mask).sum(-1) / chosen_loss_mask.sum(-1) - rejected_logps = (rejected_per_token_logps * rejected_loss_mask).sum(-1) / rejected_loss_mask.sum(-1) - chosen_logps = (chosen_per_token_logps * chosen_loss_mask).sum(-1) - rejected_logps = (rejected_per_token_logps * rejected_loss_mask).sum(-1) + chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) + chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins From 8c61d1302eb6e0b47b44e20e3601d4a13d872403 Mon Sep 17 00:00:00 2001 From: lmeribal Date: Thu, 28 Nov 2024 13:11:40 +0000 Subject: [PATCH 3/6] fix with sum --- turbo_alignment/trainers/lddpo.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/turbo_alignment/trainers/lddpo.py b/turbo_alignment/trainers/lddpo.py index 45ccc57..e0f88aa 100644 --- a/turbo_alignment/trainers/lddpo.py +++ b/turbo_alignment/trainers/lddpo.py @@ -53,11 +53,12 @@ def _get_batch_logps(self, logits: torch.Tensor, labels: torch.Tensor) -> torch. 
labels = labels[:, 1:].clone() logits = logits[:, :-1, :] + loss_mask = labels != DISABLE_LOSS_LABEL labels[labels == DISABLE_LOSS_LABEL] = 0 per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - return per_token_logps + return (per_token_logps * loss_mask), loss_mask def concatenated_forward( self, model: nn.Module, batch: dict[str, Any] @@ -71,18 +72,18 @@ def concatenated_forward( attention_mask=concatenated_batch['attention_mask'], ).logits.to(torch.float32) - loss_mask = concatenated_batch['labels'][:, 1:] != DISABLE_LOSS_LABEL + all_logps, loss_mask = self._get_batch_logps(all_logits, concatenated_batch['labels']) + batch_size = concatenated_batch['input_ids'].size(0) // 2 chosen_mask, rejected_mask = loss_mask.split(batch_size, dim=0) - all_logps = self._get_batch_logps(all_logits, concatenated_batch['labels']) - public_ = chosen_mask * rejected_mask public_mask = torch.cat([public_, public_]) public_logps = all_logps * public_mask + all_logps = self.lc_alpha * all_logps + (1 - self.lc_alpha) * public_logps chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) - return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins + return chosen_logps.sum(-1), rejected_logps.sum(-1), chosen_logits, rejected_logits, precomputed_margins From 7a6b401beff85938e8159a32415a4f94480d4a92 Mon Sep 17 00:00:00 2001 From: lmeribal Date: Thu, 28 Nov 2024 13:16:39 +0000 Subject: [PATCH 4/6] pretty --- turbo_alignment/settings/pipelines/train/lddpo.py | 2 +- turbo_alignment/trainers/lddpo.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/turbo_alignment/settings/pipelines/train/lddpo.py b/turbo_alignment/settings/pipelines/train/lddpo.py index 5c953f8..c733d58 100644 --- a/turbo_alignment/settings/pipelines/train/lddpo.py +++ b/turbo_alignment/settings/pipelines/train/lddpo.py @@ -5,7 +5,7 @@ class LDDPOTrainerSettings(DPOTrainerSettings): - lc_alpha: float = 1.0 + lc_alpha: float class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings): diff --git a/turbo_alignment/trainers/lddpo.py b/turbo_alignment/trainers/lddpo.py index e0f88aa..ab97561 100644 --- a/turbo_alignment/trainers/lddpo.py +++ b/turbo_alignment/trainers/lddpo.py @@ -79,6 +79,7 @@ def concatenated_forward( public_ = chosen_mask * rejected_mask public_mask = torch.cat([public_, public_]) + public_logps = all_logps * public_mask all_logps = self.lc_alpha * all_logps + (1 - self.lc_alpha) * public_logps From bb614286e30a80a24000f8e711989b5c664213a1 Mon Sep 17 00:00:00 2001 From: lmeribal Date: Thu, 28 Nov 2024 13:17:17 +0000 Subject: [PATCH 5/6] ld-dpo config --- tests/fixtures/configs/train/lddpo/base.json | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/fixtures/configs/train/lddpo/base.json b/tests/fixtures/configs/train/lddpo/base.json index 3140625..43eae2a 100644 --- a/tests/fixtures/configs/train/lddpo/base.json +++ b/tests/fixtures/configs/train/lddpo/base.json @@ -112,6 +112,7 @@ "pad_token": "" }, "trainer_settings": { + "lc_alpha": 0.5, "evaluation_strategy": "steps", "per_device_train_batch_size": 2, "per_device_eval_batch_size": 2, From 61ea543af844ed5b544a3fa93eaa022b1c476ef9 Mon Sep 17 00:00:00 2001 From: Elisei Rykov Date: Wed, 11 Dec 2024 13:13:39 +0300 Subject: [PATCH 6/6] in progress --- turbo_alignment/pipelines/train/lddpo.py | 33 ++++++++++---------- turbo_alignment/trainers/dpo.py | 38 +++++++++++++++--------- 2 files changed, 41 
insertions(+), 30 deletions(-) diff --git a/turbo_alignment/pipelines/train/lddpo.py b/turbo_alignment/pipelines/train/lddpo.py index 611477f..7eb3d51 100644 --- a/turbo_alignment/pipelines/train/lddpo.py +++ b/turbo_alignment/pipelines/train/lddpo.py @@ -37,22 +37,23 @@ def _get_cherry_pick_callback( tokenizer: PreTrainedTokenizerBase, **kwargs, ) -> ChatCherryPickCallback: - cherry_pick_settings = experiment_settings.cherry_pick_settings - - cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets( - cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE - ) - - metrics = [ - Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters)) - for metric in cherry_pick_settings.metric_settings - ] - - return ChatCherryPickCallback( - cherry_pick_settings=cherry_pick_settings, - datasets=cherry_pick_datasets, - metrics=metrics, - ) + return None + # cherry_pick_settings = experiment_settings.cherry_pick_settings + + # cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets( + # cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE + # ) + + # metrics = [ + # Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters)) + # for metric in cherry_pick_settings.metric_settings + # ] + + # return ChatCherryPickCallback( + # cherry_pick_settings=cherry_pick_settings, + # datasets=cherry_pick_datasets, + # metrics=metrics, + # ) def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None: logger.info(f'Train sample example:\n{dataset[0]}') diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index ef3288f..94dd7b6 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -606,12 +606,18 @@ def _get_batch_logps( return (per_token_logps * loss_mask).sum(-1) def concatenated_forward( - self, model: nn.Module, batch: dict[str, Any] + self, + model: nn.Module, + batch: dict[str, Any], + get_from_dataset=False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device) precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None) + if get_from_dataset: + return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps') + all_logits = model( concatenated_batch['input_ids'], attention_mask=concatenated_batch['attention_mask'], @@ -632,16 +638,19 @@ def concatenated_forward( return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]: - with torch.no_grad(): - if model is not None: - (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch) - else: - with self.accelerator.unwrap_model(self.model).disable_adapter(): - ( - chosen_logps, - rejected_logps, - *_, - ) = self.concatenated_forward(self.model, batch) + if self.ref_model is None: + chosen_logps, rejected_logps = self.concatenated_forward(model, batch, get_from_dataset=True) + else: + with torch.no_grad(): + if model is not None: + (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ( + chosen_logps, + rejected_logps, + *_, + ) = 
self.concatenated_forward(self.model, batch) return chosen_logps, rejected_logps @@ -780,9 +789,9 @@ def _compute_metrics( metrics[f'{prefix_name}grad_term'] = ( (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item() ) - metrics[f'{prefix_name}grad_term_std'] = ( - (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item() - ) + # metrics[f'{prefix_name}grad_term_std'] = ( + # (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item() + # ) return metrics @@ -822,6 +831,7 @@ def compute_loss( model: PreTrainedModel | nn.Module, inputs: dict[str, torch.Tensor | Any], return_outputs=False, + num_items_in_batch=None, ) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]: loss, metrics = self.get_batch_metrics(model, inputs, train_eval='train')
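
Editor's note on the length-desensitized reweighting: the version that survives after PATCH 3/6 computes masked per-token log-probs, keeps full weight for the token positions covered by both the chosen and the rejected loss masks (roughly the span of the shorter answer), and scales the remaining tokens by `lc_alpha`. The sketch below restates that computation in isolation; it is illustrative only, and the function name, tensor shapes, and toy inputs are assumptions rather than code from the patches.

```python
import torch


def ld_dpo_logps(
    per_token_logps: torch.Tensor,  # (2B, T) log-probs for the concatenated chosen+rejected batch
    chosen_mask: torch.Tensor,      # (B, T) bool loss mask for the chosen half
    rejected_mask: torch.Tensor,    # (B, T) bool loss mask for the rejected half
    lc_alpha: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size = chosen_mask.size(0)

    # "Public" positions are covered by BOTH answers' loss masks,
    # i.e. roughly up to the length of the shorter answer.
    public = (chosen_mask & rejected_mask).float()
    public_mask = torch.cat([public, public])            # (2B, T)

    full_mask = torch.cat([chosen_mask, rejected_mask]).float()
    masked_logps = per_token_logps * full_mask           # what the patched _get_batch_logps returns
    public_logps = masked_logps * public_mask

    # Tokens inside the public span keep full weight; the excess tokens of the
    # longer answer are scaled by lc_alpha (lc_alpha=1.0 recovers vanilla DPO,
    # lc_alpha=0.0 ignores everything beyond the shorter answer).
    mixed = lc_alpha * masked_logps + (1.0 - lc_alpha) * public_logps

    chosen_logps, rejected_logps = mixed.split(batch_size, dim=0)
    return chosen_logps.sum(-1), rejected_logps.sum(-1)


if __name__ == '__main__':
    torch.manual_seed(0)
    logps = -torch.rand(4, 6)  # 2 pairs, 6 token positions
    chosen = torch.tensor([[0, 1, 1, 1, 1, 1], [0, 0, 1, 1, 0, 0]], dtype=torch.bool)
    rejected = torch.tensor([[0, 1, 1, 1, 0, 0], [0, 0, 1, 1, 1, 1]], dtype=torch.bool)
    print(ld_dpo_logps(logps, chosen, rejected, lc_alpha=0.5))
```

The fixture config added in PATCH 5/6 sets `lc_alpha` to 0.5, and PATCH 4/6 removes the default value, so `trainer_settings` must supply it explicitly.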
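
PATCH 6/6 also sketches a path where reference log-probs are not recomputed at train time: when `self.ref_model is None`, `_get_logps` asks `concatenated_forward` to read `ref_chosen_logps` / `ref_rejected_logps` straight from the collated batch. A minimal illustration of that idea, assuming the dataset already stores those two fields per record (the function below is a simplified stand-in, not the repository's code):

```python
import torch


def get_reference_logps(concatenated_batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    # Skip the frozen reference-model forward pass and reuse offline-precomputed values.
    return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')


batch = {
    'ref_chosen_logps': torch.tensor([-12.3, -9.8]),
    'ref_rejected_logps': torch.tensor([-15.1, -11.4]),
}
ref_chosen, ref_rejected = get_reference_logps(batch)
print(ref_chosen, ref_rejected)
```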