From 8bc85e82fe0b76a39a569ea5e095d1f9443ce0f3 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Sat, 4 May 2024 13:33:16 -0700 Subject: [PATCH] fix mask shift --- open_lm/utils/transformers/hf_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py index b7891136..d50936ec 100644 --- a/open_lm/utils/transformers/hf_model.py +++ b/open_lm/utils/transformers/hf_model.py @@ -140,7 +140,7 @@ def forward( shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1).to(shift_logits.device) if loss_mask is not None: - shift_mask = loss_mask[..., :-1].contiguous() + shift_mask = loss_mask[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss(reduction="none") loss = loss_fct(shift_logits, shift_labels) shift_mask = torch.logical_and(shift_mask.view(-1), shift_labels != -100)