From e0acd61b6e9faec52150ff9b99a7c1ea35fe577e Mon Sep 17 00:00:00 2001 From: oltsy Date: Sun, 1 Dec 2024 02:08:35 +0300 Subject: [PATCH 1/4] Add: nca pair loss --- .../settings/pipelines/train/dpo.py | 6 ++++ turbo_alignment/trainers/dpo.py | 31 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py index bc85133..3d35fa3 100755 --- a/turbo_alignment/settings/pipelines/train/dpo.py +++ b/turbo_alignment/settings/pipelines/train/dpo.py @@ -26,6 +26,7 @@ class DPOLossesType(str, Enum): APO_DOWN = 'apo_down' ASFT = 'asft' DPOP = 'dpop' + NCA_PAIT = 'nca_pair' class DPOLossSettings(ExtraFieldsNotAllowedBaseModel): @@ -93,6 +94,10 @@ class DPOPLossSettings(DPOLossSettings): lam: float = 0.1 +class NCAPairLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.NCA_PAIT] + + class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel): sync_ref_model: bool = False alpha: float = 1.0 @@ -114,6 +119,7 @@ class DPOTrainerSettings(TrainerSettings): | APOZeroLossSettings | APODownLossSettings | DPOPLossSettings + | NCAPairLossSettings ) sync_ref_settings: SyncRefModelSettings use_ref_model: bool = True diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index ef3288f..7473987 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -39,6 +39,7 @@ SimPOLossSettings, SlicHfLossSettings, SyncRefModelSettings, + NCAPairLossSettings ) from turbo_alignment.trainers.utils import ( DPOLossRegistry, @@ -456,6 +457,35 @@ def compute_loss( return loss, chosen_rewards, rejected_rewards +@DPOLossRegistry.register(DPOLossesType.NCA_PAIT) +class NCAPairLoss(DPOLossRegistry): + def __init__(self, *args, beta: float = 0.1, **kwargs) -> None: + self.beta = beta + super().__init__(*args, **kwargs) + + def compute_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + precomputed_margins: torch.FloatTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + loss = ( + -F.logsigmoid(chosen_logratios * self.beta) + - 0.5 * F.logsigmoid(-chosen_logratios * self.beta) + - 0.5 * F.logsigmoid(-rejected_logratios * self.beta) + ) + + return loss, chosen_rewards, rejected_rewards + + @dataclass class DPOTrainingArguments(TrainingArguments): loss_settings: ( @@ -473,6 +503,7 @@ class DPOTrainingArguments(TrainingArguments): | APOZeroLossSettings | APODownLossSettings | DPOPLossSettings + | NCAPairLossSettings ) = field( default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID) ) # type: ignore[call-overload] From bdbd3794aafbde9ed9cbca0c9c7cd37c7cfc5a34 Mon Sep 17 00:00:00 2001 From: oltsy Date: Mon, 2 Dec 2024 11:43:51 +0300 Subject: [PATCH 2/4] Fix: linters --- turbo_alignment/trainers/dpo.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index 7473987..40e9091 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -33,13 +33,13 @@ HingeLossSettings, IPOLossSettings, KTOLossSettings, + NCAPairLossSettings, ORPOLossSettings, SigmoidLossSettings, SigmoidLossWithMarginSettings, SimPOLossSettings, SlicHfLossSettings, SyncRefModelSettings, - NCAPairLossSettings ) from turbo_alignment.trainers.utils import ( DPOLossRegistry, @@ -482,7 +482,7 @@ def compute_loss( - 0.5 * F.logsigmoid(-chosen_logratios * self.beta) - 0.5 * F.logsigmoid(-rejected_logratios * self.beta) ) - + return loss, chosen_rewards, rejected_rewards @@ -853,6 +853,7 @@ def compute_loss( model: PreTrainedModel | nn.Module, inputs: dict[str, torch.Tensor | Any], return_outputs=False, + num_items_in_batch=None, ) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]: loss, metrics = self.get_batch_metrics(model, inputs, train_eval='train') From 627b1f006e386c2746f81d7fc131d5c91d395aed Mon Sep 17 00:00:00 2001 From: Alexey Malakhov <131314005+alekseymalakhov11@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:25:45 +0300 Subject: [PATCH 3/4] fix name --- turbo_alignment/settings/pipelines/train/dpo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py index 3d35fa3..c08b8b8 100755 --- a/turbo_alignment/settings/pipelines/train/dpo.py +++ b/turbo_alignment/settings/pipelines/train/dpo.py @@ -26,7 +26,7 @@ class DPOLossesType(str, Enum): APO_DOWN = 'apo_down' ASFT = 'asft' DPOP = 'dpop' - NCA_PAIT = 'nca_pair' + NCA_PAIR = 'nca_pair' class DPOLossSettings(ExtraFieldsNotAllowedBaseModel): @@ -95,7 +95,7 @@ class DPOPLossSettings(DPOLossSettings): class NCAPairLossSettings(DPOLossSettings): - loss_type: Literal[DPOLossesType.NCA_PAIT] + loss_type: Literal[DPOLossesType.NCA_PAIR] class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel): From cc063a954bb3805ba8dc06961e98a6c757f6bf62 Mon Sep 17 00:00:00 2001 From: Alexey Malakhov <131314005+alekseymalakhov11@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:35:58 +0300 Subject: [PATCH 4/4] fix tests --- turbo_alignment/trainers/dpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index 40e9091..2278eee 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -457,7 +457,7 @@ def compute_loss( return loss, chosen_rewards, rejected_rewards -@DPOLossRegistry.register(DPOLossesType.NCA_PAIT) +@DPOLossRegistry.register(DPOLossesType.NCA_PAIR) class NCAPairLoss(DPOLossRegistry): def __init__(self, *args, beta: float = 0.1, **kwargs) -> None: self.beta = beta