Skip to content

Commit

Permalink
Fix hf_model loss
Browse files (browse the repository at this point in the history)
  • Loading branch information
jmercat committed May 4, 2024
1 parent d85fbc8 commit b4003c7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions open_lm/utils/transformers/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ def forward(
  shift_logits = shift_logits.view(-1, shift_logits.size(-1))
  shift_labels = shift_labels.view(-1).to(shift_logits.device)
  if loss_mask is not None:
- shift_mask = loss_mask[..., 1:].contiguous()
+ shift_mask = loss_mask[..., :-1].contiguous()
  loss_fct = nn.CrossEntropyLoss(reduction="none")
  loss = loss_fct(shift_logits, shift_labels)
- loss = loss[shift_mask.view(-1)].sum() / shift_mask.sum()
+ shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)
+ loss = loss[shift_mask.view(-1)].sum()/shift_mask.sum()
  else:
  loss_fct = nn.CrossEntropyLoss()
  loss = loss_fct(shift_logits, shift_labels)
Expand Down

0 comments on commit b4003c7

Please sign in to comment.