From 4f20d9634fb11ce8ade320e20d035e5cf0c0ef6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9C=D0=B0=D0=BB=D0=B0=D1=85=D0=BE=D0=B2=20=D0=90=D0=BB?=
 =?UTF-8?q?=D0=B5=D0=BA=D1=81=D0=B5=D0=B9=20=D0=9F=D0=B0=D0=B2=D0=BB=D0=BE?=
 =?UTF-8?q?=D0=B2=D0=B8=D1=87?=
Date: Tue, 3 Sep 2024 14:46:49 +0000
Subject: [PATCH 1/2] add precomputed margins and add concatenated forward for rm

---
 turbo_alignment/dataset/chat/models.py       |  8 +-
 .../dataset/pair_preferences/collators.py    | 19 +++--
 .../dataset/pair_preferences/models.py       |  2 +-
 .../pair_preferences/pair_preference.py      |  8 +-
 turbo_alignment/pipelines/sampling/rm.py     | 10 ++-
 .../settings/pipelines/train/dpo.py          |  4 +
 turbo_alignment/trainers/dpo.py              | 76 ++++++++++++++-----
 turbo_alignment/trainers/rm.py               | 22 ++++--
 turbo_alignment/trainers/utils.py            | 15 +++-
 9 files changed, 119 insertions(+), 45 deletions(-)

diff --git a/turbo_alignment/dataset/chat/models.py b/turbo_alignment/dataset/chat/models.py
index 314de55..139f2ca 100755
--- a/turbo_alignment/dataset/chat/models.py
+++ b/turbo_alignment/dataset/chat/models.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any
 
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 
 from turbo_alignment.dataset.base.models import DatasetRecord
 
@@ -17,6 +17,12 @@ class ChatMessage(BaseModel):
     content: str
     disable_loss: bool = False
 
+    @field_validator('role', mode='before')
+    def set_bot_role(cls, values: str) -> str:
+        if values == 'assistant':
+            return ChatMessageRole.BOT
+        return values
+
 
 class ChatDatasetRecord(DatasetRecord):
     messages: list[ChatMessage]
diff --git a/turbo_alignment/dataset/pair_preferences/collators.py b/turbo_alignment/dataset/pair_preferences/collators.py
index c0c6bfb..d19b16c 100755
--- a/turbo_alignment/dataset/pair_preferences/collators.py
+++ b/turbo_alignment/dataset/pair_preferences/collators.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Iterable
 
 import torch
 import transformers
@@ -19,7 +19,9 @@ def _get_batch(
     ) -> transformers.BatchEncoding:
         features = [ex[key] for ex in examples]
         labels = [v.tolist() for feature in features for k, v in feature.items() if k == 'labels']
-        no_labels_features = [{k: v for k, v in feature.items() if k != 'labels'} for feature in features]
+        no_labels_features = [
+            {k: v for k, v in feature.items() if k not in ['labels', 'precomputed_margin']} for feature in features
+        ]
 
         batch = tokenizer.pad(
             no_labels_features,
@@ -38,13 +40,16 @@ def __call__(self, examples: list[dict[str, dict[str, Any]]]) -> dict[str, Any]
         max_length = 0
         for ex in examples:
             for t in ex:
-                if 'input_ids' in ex[t]:
-                    max_length = max(max_length, len(ex[t]['input_ids']))
+                if isinstance(ex[t], Iterable):
+                    if 'input_ids' in ex[t]:
+                        max_length = max(max_length, len(ex[t]['input_ids']))
 
-        batch = {
+        batch: dict[str, Any] = {
             'inputs_w': dict(self._get_batch(examples, self.tokenizer, 'inputs_w', max_length)),
             'inputs_l': dict(self._get_batch(examples, self.tokenizer, 'inputs_l', max_length)),
         }
-        if 'best_decode' in examples[0] and len(examples[0]['best_decode']) != 0:
-            batch['best_decode'] = dict(self._get_batch(examples, self.tokenizer, 'best_decode', max_length))
+
+        if 'precomputed_margin' in examples[0] and examples[0]['precomputed_margin'] is not None:
+            batch['precomputed_margin'] = torch.tensor([ex['precomputed_margin'] for ex in examples])
+
         return batch
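Note: an illustrative sketch (not part of the patch) of the batch layout the collator above is expected to emit; the tensor values, lengths, and batch size are made up.

# Illustrative only: approximate shape of one collated pair-preference batch of two examples,
# padded to a common length; 'precomputed_margin' is present only when the records carry margins.
import torch

batch = {
    'inputs_w': {'input_ids': torch.tensor([[1, 5, 2, 0], [1, 7, 8, 2]]),
                 'attention_mask': torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])},
    'inputs_l': {'input_ids': torch.tensor([[1, 6, 2, 0], [1, 9, 2, 0]]),
                 'attention_mask': torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]])},
    'precomputed_margin': torch.tensor([0.4, 1.2]),
}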
diff --git a/turbo_alignment/dataset/pair_preferences/models.py b/turbo_alignment/dataset/pair_preferences/models.py
index 64c1f23..4cb099b 100755
--- a/turbo_alignment/dataset/pair_preferences/models.py
+++ b/turbo_alignment/dataset/pair_preferences/models.py
@@ -11,4 +11,4 @@ class PairPreferenceRecord(DatasetRecord):
     context: list[ChatMessage]
     answer_w: ChatMessage
     answer_l: ChatMessage
-    best_decode: ChatMessage | None = None
+    precomputed_margin: float | None = None
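Note: a hypothetical record in the updated schema (field names from PairPreferenceRecord above; the id, role strings, texts, and margin value are invented for illustration).

# Illustrative only: one preference record carrying a precomputed margin.
record = {
    'id': '0',
    'context': [{'role': 'user', 'content': 'How do I sort a list in Python?'}],
    'answer_w': {'role': 'assistant', 'content': 'Use sorted(xs) or xs.sort().'},  # mapped to the bot role by the new validator
    'answer_l': {'role': 'assistant', 'content': 'Lists cannot be sorted.'},
    'precomputed_margin': 1.25,  # e.g. a reward-model score gap between answer_w and answer_l
}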
diff --git a/turbo_alignment/dataset/pair_preferences/pair_preference.py b/turbo_alignment/dataset/pair_preferences/pair_preference.py
index 6ffe661..e01a5fe 100755
--- a/turbo_alignment/dataset/pair_preferences/pair_preference.py
+++ b/turbo_alignment/dataset/pair_preferences/pair_preference.py
@@ -82,23 +82,19 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str
             if not (chosen_record and rejected_record):
                 continue
 
-            ignore_keys = []
+            ignore_keys = ['precomputed_margin']
             if not self._add_labels:
                 ignore_keys.append('labels')
 
             chosen_tokens = {k: v.squeeze(0) for k, v in chosen_record.items() if k not in ignore_keys}
             rejected_tokens = {k: v.squeeze(0) for k, v in rejected_record.items() if k not in ignore_keys}
-            best_decode_tokens = {}
-
-            if best_decode_record is not None:
-                best_decode_tokens = {k: v.squeeze(0) for k, v in best_decode_record.items() if k not in ignore_keys}
 
             output.append(
                 {
                     'id': record.id,
                     'inputs_w': chosen_tokens,
                     'inputs_l': rejected_tokens,
-                    'best_decode': best_decode_tokens,
+                    'precomputed_margin': record.precomputed_margin,
                 }
             )
diff --git a/turbo_alignment/pipelines/sampling/rm.py b/turbo_alignment/pipelines/sampling/rm.py
index 6b68ed8..49c7783 100755
--- a/turbo_alignment/pipelines/sampling/rm.py
+++ b/turbo_alignment/pipelines/sampling/rm.py
@@ -3,6 +3,8 @@
 
 from accelerate import Accelerator
 from accelerate.utils import set_seed
+from accelerate.utils.operations import gather_object
+
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from turbo_alignment.common.tf.loaders.model.model import load_model
@@ -36,7 +38,9 @@ def _get_rewards(
             batch=experiment_settings.rm_batch_size,
             micro_batch=experiment_settings.rm_batch_size,
         )
-        outputs: list[RMSamplingInferenceOutput] = generator.generate_from_dataset(dataset)
+        outputs: list[RMSamplingInferenceOutput] = gather_object(generator.generate_from_dataset(dataset))[
+            : len(dataset)
+        ]
         return {out.id: out.rewards for out in outputs}
 
     @staticmethod
@@ -59,7 +63,9 @@ def sample(self, experiment_settings: SamplingSettingsWithRMT) -> list[SamplingD
         model = load_model(experiment_settings.rm, tokenizer)
         if accelerator is not None:
-            model.to(accelerator.device)
+            model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)
+
+        model.eval()
 
         dataset = SamplingRMDataset(
            source=experiment_settings.dataset_settings.sources[0],
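Note: a minimal sketch (not part of the patch) of why the gathered outputs above are sliced to len(dataset): gather_object returns every rank's results as one flat list on each process, and distributed samplers may pad the last batch, so trimming to the dataset size drops the padded duplicates. The dataset size below is a placeholder.

# Illustrative only: gathering per-process outputs with accelerate.
from accelerate import Accelerator
from accelerate.utils import gather_object

accelerator = Accelerator()
local_outputs = [f'rank{accelerator.process_index}-sample{i}' for i in range(2)]
all_outputs = gather_object(local_outputs)  # the same flattened list on every rank
all_outputs = all_outputs[:3]               # hypothetical true dataset size; drops sampler padding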
diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py
index 64d0a43..b29e277 100755
--- a/turbo_alignment/settings/pipelines/train/dpo.py
+++ b/turbo_alignment/settings/pipelines/train/dpo.py
@@ -14,6 +14,7 @@ class DPOLossesType(str, Enum):
     SIGMOID = 'sigmoid'
+    SIGMOD_WITH_MARGIN = 'sigmoid_with_margin'
     HINGE = 'hinge'
     IPO = 'ipo'
     KTO = 'kto'
@@ -36,6 +37,8 @@ class SigmoidLossSettings(DPOLossSettings):
     loss_type: Literal[DPOLossesType.SIGMOID]
     label_smoothing: float = 0
 
+class SigmoidLossWithMarginSettings(SigmoidLossSettings):
+    loss_type: Literal[DPOLossesType.SIGMOD_WITH_MARGIN]
 
 class HingeLossSettings(DPOLossSettings):
     loss_type: Literal[DPOLossesType.HINGE]
@@ -86,6 +89,7 @@ class DPOTrainerSettings(TrainerSettings):
         | ORPOLossSettings
         | SimPOLossSettings
         | SlicHfLossSettings
+        | SigmoidLossWithMarginSettings
     )
     sync_ref_settings: SyncRefModelSettings
     use_ref_model: bool = True
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index ff17c91..f9b654a 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -57,7 +57,7 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -91,7 +91,7 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
         rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
@@ -125,7 +125,7 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -153,7 +153,7 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -184,7 +184,7 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -216,7 +216,7 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
         ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -231,8 +231,8 @@ def compute_loss(
         if self.norm:
             loss = torch.relu(self.delta - self.beta * logits)
 
-        if policy_best_decode_logps is not None:
-            loss = loss - self.lam * policy_best_decode_logps
+        if precomputed_margins is not None:
+            loss = loss - self.lam * precomputed_margins
 
         return loss, chosen_rewards, rejected_rewards
@@ -250,7 +250,7 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor | None,
         reference_rejected_logps: torch.FloatTensor | None,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
         pi_logratios = policy_chosen_logps - policy_rejected_logps
@@ -280,7 +280,7 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor | None,
         reference_rejected_logps: torch.FloatTensor | None,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         log_odds = (policy_chosen_logps - policy_rejected_logps) - (
             torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
@@ -294,6 +294,42 @@ def compute_loss(
         rejected_rewards = self.beta * policy_rejected_logps.detach()
 
         return losses, chosen_rewards, rejected_rewards
+
+
+@DPOLossRegistry.register(DPOLossesType.SIGMOID_WITH_MARGIN)
+class SigmoidLossWithMargin(DPOLossRegistry):
+    def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
+        self.beta = beta
+        super().__ini__(*args, **kwargs)
+
+
+    def compute_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pi_logratios = policy_chosen_logps - policy_rejected_logps
+        ref_logratios = reference_chosen_logps - reference_rejected_logps
+
+        logits = pi_logratios - ref_logratios
+
+        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+        if precomputed_margins is None:
+            raise ValueError('Precomputed margins should not be none when using SigmoidLossWithMargin')
+
+        loss = -F.logsigmoid(self.beta * logits - precomputed_margins)
+
+        return (
+            loss,
+            chosen_rewards,
+            rejected_rewards,
+        )
 
 
 @dataclass
@@ -308,6 +344,7 @@ class DPOTrainingArguments(TrainingArguments):
         | ORPOLossSettings
        | SimPOLossSettings
         | SlicHfLossSettings
+        | SigmoidLossWithMarginSettings
     ) = field(
         default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
     )  # type: ignore[call-overload]
@@ -402,14 +439,14 @@ def dpo_loss(
         policy_rejected_logps: torch.Tensor,
         reference_chosen_logps: torch.Tensor,
         reference_rejected_logps: torch.Tensor,
-        policy_best_decode_logps: torch.Tensor | None,
+        precomputed_margins: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return self.dpo_loss_registry.compute_loss(
             policy_chosen_logps=policy_chosen_logps,
             policy_rejected_logps=policy_rejected_logps,
             reference_chosen_logps=reference_chosen_logps,
             reference_rejected_logps=reference_rejected_logps,
-            policy_best_decode_logps=policy_best_decode_logps,
+            precomputed_margins=precomputed_margins,
         )
 
     def _get_batch_logps(
@@ -437,6 +474,9 @@ def concatenated_forward(
         self, model: nn.Module, batch: dict[str, Any]
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
         concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
+
+        precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)
+
         all_logits = model(
             concatenated_batch['input_ids'],
             attention_mask=concatenated_batch['attention_mask'],
@@ -452,13 +492,9 @@
         chosen_logps = all_logps[:chosen_idxs]
         rejected_logps = all_logps[chosen_idxs : chosen_idxs + rejected_idx]
 
-        policy_best_decode_logps: torch.Tensor = all_logps[chosen_idxs + rejected_idx :]
-        if len(policy_best_decode_logps) == 0:
-            policy_best_decode_logps = None  # type: ignore[assignment]
-
         chosen_logits = all_logits[:chosen_idxs]
         rejected_logits = all_logits[chosen_idxs:]
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, policy_best_decode_logps
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
 
     def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
         with torch.no_grad():
@@ -487,7 +523,7 @@ def get_batch_metrics(
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
-            policy_best_decode_logps,
+            precomputed_margins,
         ) = self.concatenated_forward(model, batch)
 
         reference_chosen_logps, reference_rejected_logps = torch.Tensor([float('inf')]), torch.Tensor([float('inf')])
@@ -500,7 +536,7 @@
             policy_rejected_logps=policy_rejected_logps,
             reference_chosen_logps=reference_chosen_logps,
             reference_rejected_logps=reference_rejected_logps,
-            policy_best_decode_logps=policy_best_decode_logps,
+            precomputed_margins=precomputed_margins,
         )
 
         prefix = 'eval_' if train_eval == 'eval' else ''
@@ -561,7 +597,7 @@
                 policy_rejected_logps=policy_rejected_logps,
                 reference_chosen_logps=sft_chosen_logps,
                 reference_rejected_logps=sft_rejected_logps,
-                policy_best_decode_logps=policy_best_decode_logps,
+                precomputed_margins=precomputed_margins,
             )
 
             sft_prefix_name = prefix + 'rewards/sft_'
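Note: a minimal sketch (not part of the patch) of the margin-adjusted sigmoid objective added above: per pair, loss = -log(sigmoid(beta * (policy logratio - reference logratio) - margin)). The beta, log-ratio, and margin values below are placeholders.

# Illustrative only: sigmoid DPO loss shifted by a precomputed per-pair margin.
import torch
import torch.nn.functional as F

beta = 0.1
pi_logratios = torch.tensor([1.5, 0.2])    # policy_chosen_logps - policy_rejected_logps
ref_logratios = torch.tensor([0.5, 0.1])   # reference_chosen_logps - reference_rejected_logps
margins = torch.tensor([0.3, 0.0])         # precomputed_margin per pair, e.g. from a reward model

loss = -F.logsigmoid(beta * (pi_logratios - ref_logratios) - margins)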
diff --git a/turbo_alignment/trainers/rm.py b/turbo_alignment/trainers/rm.py
index 6c119dc..2ecf755 100755
--- a/turbo_alignment/trainers/rm.py
+++ b/turbo_alignment/trainers/rm.py
@@ -7,17 +7,29 @@
 from transformers.utils import logging
 
 from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
+from turbo_alignment.trainers.utils import concatenated_inputs
+
 
 logger = logging.get_logger(__name__)
 
 
 class RMTrainer(MultiGPUCherryPicksTrainer):
-    def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
-        inputs_w = inputs['inputs_w']
-        inputs_l = inputs['inputs_l']
-        rewards_w = model(**inputs_w, return_dict=True)[0]
-        rewards_l = model(**inputs_l, return_dict=True)[0]
+    def concatenated_forward(self, model: nn.Module, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
+        concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
+        all_rewards = model(
+            concatenated_batch['input_ids'], attention_mask=concatenated_batch['attention_mask'], return_dict=True
+        )[0]
+
+        chosen_idxs = batch['inputs_w']['input_ids'].shape[0]
+
+        chosen_rewards = all_rewards[:chosen_idxs]
+        rejected_rewards = all_rewards[chosen_idxs:]
+
+        return chosen_rewards, rejected_rewards
+
+    def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
+        rewards_w, rewards_l = self.concatenated_forward(model, inputs)
 
         loss = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean()
 
         if return_outputs:
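Note: a minimal sketch (not part of the patch) of the single-pass reward scoring idea above: chosen and rejected inputs are concatenated along the batch dimension, scored in one forward pass, then split back. The scoring function and token ids below are placeholders for the real reward model.

# Illustrative only: a stand-in reward head over a concatenated chosen+rejected batch.
import torch

def score(input_ids: torch.Tensor) -> torch.Tensor:
    return input_ids.float().mean(dim=1, keepdim=True)  # placeholder for a real reward model

chosen_ids = torch.tensor([[1, 5, 2], [1, 7, 2]])
rejected_ids = torch.tensor([[1, 6, 2], [1, 9, 2]])

all_rewards = score(torch.cat([chosen_ids, rejected_ids], dim=0))  # one forward pass
chosen_rewards, rejected_rewards = all_rewards[:2], all_rewards[2:]
loss = -torch.nn.functional.logsigmoid(chosen_rewards - rejected_rewards).mean()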
diff --git a/turbo_alignment/trainers/utils.py b/turbo_alignment/trainers/utils.py
index 8021d0a..e3fabf8 100755
--- a/turbo_alignment/trainers/utils.py
+++ b/turbo_alignment/trainers/utils.py
@@ -19,15 +19,24 @@ def concatenated_inputs(
     """
     grouped_batch: dict[str, list[torch.Tensor]] = defaultdict(list)
+    no_grouped_batch_items: dict[str, Any] = {}
     for outcome_key, outcome_inputs in batch.items():
         if outcome_key.startswith(prefix):
-            for k, v in outcome_inputs.items():
-                grouped_batch[k].append(v)
+            if isinstance(outcome_inputs, dict):
+                for k, v in outcome_inputs.items():
+                    grouped_batch[k].append(v)
+            else:
+                no_grouped_batch_items[outcome_key] = outcome_inputs
 
     concatenated_batch: dict[str, torch.Tensor] = {}
     for k, v in grouped_batch.items():
         concatenated_batch[k] = torch.cat(v, dim=0).to(device)
 
+    for k, v in no_grouped_batch_items.items():
+        if isinstance(v, torch.Tensor):
+            v = v.to(device)
+        concatenated_batch[k] = v
+
     return concatenated_batch
 
 
@@ -86,6 +95,6 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
+        precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         ...

From 323e99f1e9db3be295ad6bdafb62323ca2e9986a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9C=D0=B0=D0=BB=D0=B0=D1=85=D0=BE=D0=B2=20=D0=90=D0=BB?=
 =?UTF-8?q?=D0=B5=D0=BA=D1=81=D0=B5=D0=B9=20=D0=9F=D0=B0=D0=B2=D0=BB=D0=BE?=
 =?UTF-8?q?=D0=B2=D0=B8=D1=87?=
Date: Tue, 3 Sep 2024 15:34:30 +0000
Subject: [PATCH 2/2] fix linters and tests

---
 .../dataset/pair_preferences/pair_preference.py  | 14 ++------------
 turbo_alignment/pipelines/sampling/rm.py         |  3 +--
 turbo_alignment/settings/pipelines/train/dpo.py  |  8 +++++---
 turbo_alignment/trainers/dpo.py                  |  7 +++----
 turbo_alignment/trainers/rm.py                   |  2 --
 5 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/turbo_alignment/dataset/pair_preferences/pair_preference.py b/turbo_alignment/dataset/pair_preferences/pair_preference.py
index e01a5fe..fc6ff70 100755
--- a/turbo_alignment/dataset/pair_preferences/pair_preference.py
+++ b/turbo_alignment/dataset/pair_preferences/pair_preference.py
@@ -49,7 +49,6 @@ def __init__(
     def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str, Any] | None]:
         chosen_chat_records: list[ChatDatasetRecord] = []
         rejected_chat_records: list[ChatDatasetRecord] = []
-        best_decode_records: list[ChatDatasetRecord] = []
 
         for record in records:
             context = [
@@ -60,24 +59,15 @@ def convert_records(self, records: list[PairPreferenceRecord]) -> list[dict[str
             chosen = ChatMessage(role=record.answer_w.role, content=record.answer_w.content)
             rejected = ChatMessage(role=record.answer_l.role, content=record.answer_l.content)
 
-            if record.best_decode is not None:
-                best_decode = ChatMessage(role=record.best_decode.role, content=record.best_decode.content)
-                best_decode_records.append(ChatDatasetRecord(id=record.id, messages=context + [best_decode]))
-
             chosen_chat_records.append(ChatDatasetRecord(id=record.id, messages=context + [chosen]))
             rejected_chat_records.append(ChatDatasetRecord(id=record.id, messages=context + [rejected]))
 
         tokenized_chosen_records = self._chat_dataset.convert_records(chosen_chat_records)
         tokenized_rejected_records = self._chat_dataset.convert_records(rejected_chat_records)
 
-        if len(best_decode_records) != 0:
-            tokenized_best_decode_records = self._chat_dataset.convert_records(best_decode_records)
-        else:
-            tokenized_best_decode_records = [None] * len(tokenized_chosen_records)
-
         output: list[dict[str, Any] | None] = []
-        for record, chosen_record, rejected_record, best_decode_record in zip(
-            records, tokenized_chosen_records, tokenized_rejected_records, tokenized_best_decode_records
+        for record, chosen_record, rejected_record in zip(
+            records, tokenized_chosen_records, tokenized_rejected_records
         ):
             if not (chosen_record and rejected_record):
                 continue
diff --git a/turbo_alignment/pipelines/sampling/rm.py b/turbo_alignment/pipelines/sampling/rm.py
index 49c7783..86b9f8c 100755
--- a/turbo_alignment/pipelines/sampling/rm.py
+++ b/turbo_alignment/pipelines/sampling/rm.py
@@ -4,7 +4,6 @@
 from accelerate import Accelerator
 from accelerate.utils import set_seed
 from accelerate.utils.operations import gather_object
-
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from turbo_alignment.common.tf.loaders.model.model import load_model
@@ -64,7 +63,7 @@ def sample(self, experiment_settings: SamplingSettingsWithRMT) -> list[SamplingD
         model = load_model(experiment_settings.rm, tokenizer)
         if accelerator is not None:
             model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)
-
+
         model.eval()
 
         dataset = SamplingRMDataset(
diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py
index b29e277..9464031 100755
--- a/turbo_alignment/settings/pipelines/train/dpo.py
+++ b/turbo_alignment/settings/pipelines/train/dpo.py
@@ -14,7 +14,7 @@ class DPOLossesType(str, Enum):
     SIGMOID = 'sigmoid'
-    SIGMOD_WITH_MARGIN = 'sigmoid_with_margin'
+    SIGMOID_WITH_MARGIN = 'sigmoid_with_margin'
     HINGE = 'hinge'
     IPO = 'ipo'
     KTO = 'kto'
@@ -37,8 +37,10 @@ class SigmoidLossSettings(DPOLossSettings):
     loss_type: Literal[DPOLossesType.SIGMOID]
     label_smoothing: float = 0
 
-class SigmoidLossWithMarginSettings(SigmoidLossSettings):
-    loss_type: Literal[DPOLossesType.SIGMOD_WITH_MARGIN]
+
+class SigmoidLossWithMarginSettings(DPOLossSettings):
+    loss_type: Literal[DPOLossesType.SIGMOID_WITH_MARGIN]
+
 
 class HingeLossSettings(DPOLossSettings):
     loss_type: Literal[DPOLossesType.HINGE]
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index f9b654a..0ce83db 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -31,6 +31,7 @@
     KTOLossSettings,
     ORPOLossSettings,
     SigmoidLossSettings,
+    SigmoidLossWithMarginSettings,
     SimPOLossSettings,
     SlicHfLossSettings,
     SyncRefModelSettings,
@@ -294,14 +295,13 @@ def compute_loss(
         rejected_rewards = self.beta * policy_rejected_logps.detach()
 
         return losses, chosen_rewards, rejected_rewards
-
+
 
 @DPOLossRegistry.register(DPOLossesType.SIGMOID_WITH_MARGIN)
 class SigmoidLossWithMargin(DPOLossRegistry):
     def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
         self.beta = beta
-        super().__ini__(*args, **kwargs)
-
+        super().__init__(*args, **kwargs)
 
     def compute_loss(
         self,
@@ -309,7 +309,6 @@ def compute_loss(
         policy_rejected_logps: torch.FloatTensor,
         reference_chosen_logps: torch.FloatTensor,
         reference_rejected_logps: torch.FloatTensor,
-        policy_best_decode_logps: torch.FloatTensor | None,
         precomputed_margins: torch.FloatTensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pi_logratios = policy_chosen_logps - policy_rejected_logps
diff --git a/turbo_alignment/trainers/rm.py b/turbo_alignment/trainers/rm.py
index 2ecf755..d78a7e9 100755
--- a/turbo_alignment/trainers/rm.py
+++ b/turbo_alignment/trainers/rm.py
@@ -9,12 +9,10 @@
 from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
 from turbo_alignment.trainers.utils import concatenated_inputs
-
 
 logger = logging.get_logger(__name__)
 
 
 class RMTrainer(MultiGPUCherryPicksTrainer):
-
     def concatenated_forward(self, model: nn.Module, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
         concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
         all_rewards = model(