From 8c61d1302eb6e0b47b44e20e3601d4a13d872403 Mon Sep 17 00:00:00 2001
From: lmeribal
Date: Thu, 28 Nov 2024 13:11:40 +0000
Subject: [PATCH] fix with sum

---
 turbo_alignment/trainers/lddpo.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/turbo_alignment/trainers/lddpo.py b/turbo_alignment/trainers/lddpo.py
index 45ccc57..e0f88aa 100644
--- a/turbo_alignment/trainers/lddpo.py
+++ b/turbo_alignment/trainers/lddpo.py
@@ -53,11 +53,12 @@ def _get_batch_logps(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.
 
         labels = labels[:, 1:].clone()
         logits = logits[:, :-1, :]
+        loss_mask = labels != DISABLE_LOSS_LABEL
         labels[labels == DISABLE_LOSS_LABEL] = 0
 
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
 
-        return per_token_logps
+        return (per_token_logps * loss_mask), loss_mask
 
     def concatenated_forward(
         self, model: nn.Module, batch: dict[str, Any]
@@ -71,18 +72,18 @@ def concatenated_forward(
             attention_mask=concatenated_batch['attention_mask'],
         ).logits.to(torch.float32)
-        loss_mask = concatenated_batch['labels'][:, 1:] != DISABLE_LOSS_LABEL
+        all_logps, loss_mask = self._get_batch_logps(all_logits, concatenated_batch['labels'])
+
         batch_size = concatenated_batch['input_ids'].size(0) // 2
 
         chosen_mask, rejected_mask = loss_mask.split(batch_size, dim=0)
 
-        all_logps = self._get_batch_logps(all_logits, concatenated_batch['labels'])
-
         public_ = chosen_mask * rejected_mask
         public_mask = torch.cat([public_, public_])
         public_logps = all_logps * public_mask
+
         all_logps = self.lc_alpha * all_logps + (1 - self.lc_alpha) * public_logps
 
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
 
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
+        return chosen_logps.sum(-1), rejected_logps.sum(-1), chosen_logits, rejected_logits, precomputed_margins
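
Note (outside the patch): below is a minimal, self-contained sketch of the behaviour this change introduces: _get_batch_logps now returns masked per-token log-probs together with the loss mask, and the caller sums over the sequence dimension to obtain sequence-level log-probs. The value DISABLE_LOSS_LABEL = -100, the standalone helper name get_batch_logps, and the toy tensor shapes are assumptions made for the example, not values taken from the repository.

import torch

DISABLE_LOSS_LABEL = -100  # assumed ignore-index value for this sketch

def get_batch_logps(logits: torch.Tensor, labels: torch.Tensor):
    # Mirrors the patched _get_batch_logps: shift labels/logits, build a loss mask,
    # and return masked per-token log-probs together with the mask.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != DISABLE_LOSS_LABEL
    labels[labels == DISABLE_LOSS_LABEL] = 0
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return per_token_logps * loss_mask, loss_mask

if __name__ == '__main__':
    batch, seq_len, vocab = 2, 6, 11
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))
    labels[:, :2] = DISABLE_LOSS_LABEL  # e.g. prompt tokens excluded from the loss
    per_token_logps, mask = get_batch_logps(logits, labels)
    # Summing over the last dim (as the patched return does with .sum(-1)) yields one
    # log-prob per sequence; masked positions contribute exactly zero to the sum.
    print(per_token_logps.sum(-1))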