diff --git a/turbo_alignment/trainers/ref_model.py b/turbo_alignment/trainers/ref_model.py index 4dd5d93..e441b75 100755 --- a/turbo_alignment/trainers/ref_model.py +++ b/turbo_alignment/trainers/ref_model.py @@ -149,7 +149,7 @@ def compute_loss( gc.collect() torch.cuda.empty_cache() - return torch.tensor([0.0], requires_grad=True) + return torch.tensor([0.0], requires_grad=True).to(self.accelerator.device) def prediction_step( self, @@ -159,6 +159,10 @@ def prediction_step( ignore_keys: list[str] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: if prediction_loss_only: - return torch.tensor([0.0]).detach(), None, None + return torch.tensor([0.0]).to(self.accelerator.device).detach(), None, None - return torch.tensor([0.0]).detach(), torch.tensor([0.0]), torch.tensor([0.0]) + return ( + torch.tensor([0.0]).to(self.accelerator.device).detach(), + torch.tensor([0.0]).to(self.accelerator.device), + torch.tensor([0.0]).to(self.accelerator.device), + )