Skip to content

Commit

Permalink
fix with sum
Browse files Browse the repository at this point in the history
  • Loading branch information
lmeribal committed Nov 28, 2024
1 parent e083ef6 commit 8c61d13
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions turbo_alignment/trainers/lddpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@ def _get_batch_logps(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.

labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_mask = labels != DISABLE_LOSS_LABEL

labels[labels == DISABLE_LOSS_LABEL] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
return per_token_logps
return (per_token_logps * loss_mask), loss_mask

def concatenated_forward(
self, model: nn.Module, batch: dict[str, Any]
Expand All @@ -71,18 +72,18 @@ def concatenated_forward(
attention_mask=concatenated_batch['attention_mask'],
).logits.to(torch.float32)

loss_mask = concatenated_batch['labels'][:, 1:] != DISABLE_LOSS_LABEL
all_logps, loss_mask = self._get_batch_logps(all_logits, concatenated_batch['labels'])

batch_size = concatenated_batch['input_ids'].size(0) // 2
chosen_mask, rejected_mask = loss_mask.split(batch_size, dim=0)

all_logps = self._get_batch_logps(all_logits, concatenated_batch['labels'])

public_ = chosen_mask * rejected_mask
public_mask = torch.cat([public_, public_])
public_logps = all_logps * public_mask

all_logps = self.lc_alpha * all_logps + (1 - self.lc_alpha) * public_logps

chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)

return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
return chosen_logps.sum(-1), rejected_logps.sum(-1), chosen_logits, rejected_logits, precomputed_margins

0 comments on commit 8c61d13

Please sign in to comment.