Skip to content

Commit

Permalink
in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Nov 11, 2024
1 parent 10738ec commit beaab56
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions turbo_alignment/trainers/ref_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def compute_loss(

gc.collect()
torch.cuda.empty_cache()
return torch.tensor([0.0], requires_grad=True)
return torch.tensor([0.0], requires_grad=True).to(self.accelerator.device)

def prediction_step(
self,
Expand All @@ -159,6 +159,10 @@ def prediction_step(
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
if prediction_loss_only:
return torch.tensor([0.0]).detach(), None, None
return torch.tensor([0.0]).to(self.accelerator.device).detach(), None, None

return torch.tensor([0.0]).detach(), torch.tensor([0.0]), torch.tensor([0.0])
return (
torch.tensor([0.0]).to(self.accelerator.device).detach(),
torch.tensor([0.0]).to(self.accelerator.device),
torch.tensor([0.0]).to(self.accelerator.device),
)

0 comments on commit beaab56

Please sign in to comment.