Skip to content

Commit

Permalink
fix mask shift
Browse files Browse the repository at this point in the history
  • Loading branch information
jmercat committed May 8, 2024
1 parent 46a45a4 commit 8bc85e8
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion open_lm/utils/transformers/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def forward(
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1).to(shift_logits.device)
if loss_mask is not None:
shift_mask = loss_mask[..., :-1].contiguous()
shift_mask = loss_mask[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(shift_logits, shift_labels)
shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)
Expand Down

0 comments on commit 8bc85e8

Please sign in to comment.