From 7a6b401beff85938e8159a32415a4f94480d4a92 Mon Sep 17 00:00:00 2001 From: lmeribal Date: Thu, 28 Nov 2024 13:16:39 +0000 Subject: [PATCH] pretty --- turbo_alignment/settings/pipelines/train/lddpo.py | 2 +- turbo_alignment/trainers/lddpo.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/turbo_alignment/settings/pipelines/train/lddpo.py b/turbo_alignment/settings/pipelines/train/lddpo.py index 5c953f8..c733d58 100644 --- a/turbo_alignment/settings/pipelines/train/lddpo.py +++ b/turbo_alignment/settings/pipelines/train/lddpo.py @@ -5,7 +5,7 @@ class LDDPOTrainerSettings(DPOTrainerSettings): - lc_alpha: float = 1.0 + lc_alpha: float class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings): diff --git a/turbo_alignment/trainers/lddpo.py b/turbo_alignment/trainers/lddpo.py index e0f88aa..ab97561 100644 --- a/turbo_alignment/trainers/lddpo.py +++ b/turbo_alignment/trainers/lddpo.py @@ -79,6 +79,7 @@ def concatenated_forward( public_ = chosen_mask * rejected_mask public_mask = torch.cat([public_, public_]) + public_logps = all_logps * public_mask all_logps = self.lc_alpha * all_logps + (1 - self.lc_alpha) * public_logps