Skip to content

Commit

Permalink
[BugFix] Fix MPS sync in device transform (#2061)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 7, 2024
1 parent f85da4c commit 4488c25
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3805,7 +3805,7 @@ def __init__(
if torch.cuda.is_available():
self._sync_device = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_device = torch.cuda.synchronize
self._sync_device = torch.mps.synchronize
elif device.type == "cpu":
self._sync_device = _do_nothing
else:
Expand Down Expand Up @@ -3901,7 +3901,7 @@ def _sync_orig_device(self):
if torch.cuda.is_available():
self._sync_orig_device_val = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_orig_device_val = torch.cuda.synchronize
self._sync_orig_device_val = torch.mps.synchronize
elif device.type == "cpu":
self._sync_orig_device_val = _do_nothing
else:
Expand Down

0 comments on commit 4488c25

Please sign in to comment.