From 4488c2554cf53ef77a3f90e34252c2529355e1ce Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 7 Apr 2024 17:08:15 +0200 Subject: [PATCH] [BugFix] Fix MPS sync in device transform (#2061) --- torchrl/envs/transforms/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e25ed08ebed..b83de8b71f8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3805,7 +3805,7 @@ def __init__( if torch.cuda.is_available(): self._sync_device = torch.cuda.synchronize elif torch.backends.mps.is_available(): - self._sync_device = torch.cuda.synchronize + self._sync_device = torch.mps.synchronize elif device.type == "cpu": self._sync_device = _do_nothing else: @@ -3901,7 +3901,7 @@ def _sync_orig_device(self): if torch.cuda.is_available(): self._sync_orig_device_val = torch.cuda.synchronize elif torch.backends.mps.is_available(): - self._sync_orig_device_val = torch.cuda.synchronize + self._sync_orig_device_val = torch.mps.synchronize elif device.type == "cpu": self._sync_orig_device_val = _do_nothing else: