diff --git a/test/test_tensordict.py b/test/test_tensordict.py index af6a2edfb..737ff4f24 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -142,8 +142,8 @@ def device_fixture(): device = torch.get_default_device() if torch.cuda.is_available(): torch.set_default_device(torch.device("cuda:0")) - elif torch.backends.mps.is_available(): - torch.set_default_device(torch.device("mps:0")) + # elif torch.backends.mps.is_available(): + # torch.set_default_device(torch.device("mps:0")) yield torch.set_default_device(device) @@ -1468,8 +1468,8 @@ def check_meta(tensor): if torch.cuda.is_available(): device = "cuda:0" - elif torch.backends.mps.is_available(): - device = "mps:0" + # elif torch.backends.mps.is_available(): + # device = "mps:0" else: pytest.skip("no device to test") device_state_dict = TensorDict.load(tmpdir, device=device) @@ -1717,8 +1717,8 @@ def test_no_batch_size(self): def test_non_blocking(self): if torch.cuda.is_available(): device = "cuda" - elif torch.backends.mps.is_available(): - device = "mps" + # elif torch.backends.mps.is_available(): + # device = "mps" else: pytest.skip("No device found") for _ in range(10): @@ -1792,9 +1792,9 @@ def test_non_blocking_single_sync(self, _path_td_sync): TensorDict(td_dict, device="cpu") assert _SYNC_COUNTER == 0 - if torch.backends.mps.is_available(): - device = "mps" - elif torch.cuda.is_available(): + # if torch.backends.mps.is_available(): + # device = "mps" + if torch.cuda.is_available(): device = "cuda" else: device = None @@ -9857,7 +9857,8 @@ def check_weakref_count(weakref_list, expected): assert count == expected, {id(ref()) for ref in weakref_list} @pytest.mark.skipif( - not torch.cuda.is_available() and not torch.backends.mps.is_available(), + not torch.cuda.is_available(), + # and not torch.backends.mps.is_available(), reason="a device is required.", ) def test_cached_data_lock_device(self):