Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
haeggee committed Sep 4, 2024
1 parent f31a1a3 commit a45dc35
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
lang_losses_list = list(lang_losses.keys())

# Compute losses
if isinstance(outputs[0], torch.Tensor):
if len(outputs) > 0 and isinstance(outputs[0], torch.Tensor):
# Multilingual losses
for loss, lang_code in zip(outputs, lang_codes):
lang_losses[lang_losses_list[lang_code]].append(loss)
Expand All @@ -703,7 +703,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
if not lang_losses[
lang
]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation
lang_losses[lang] = torch.tensor(-1, dtype=torch.float32)
lang_losses[lang] = torch.tensor(-1, dtype=torch.float32, device="cuda")
else: # If we have at least 1 loss from a given language --> compute local language loss mean
lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang]))

Expand Down

0 comments on commit a45dc35

Please sign in to comment.