Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent 9fa66cb commit ab694e6
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def device_fixture():
device = torch.get_default_device()
if torch.cuda.is_available():
torch.set_default_device(torch.device("cuda:0"))
elif torch.backends.mps.is_available():
torch.set_default_device(torch.device("mps:0"))
# elif torch.backends.mps.is_available():
# torch.set_default_device(torch.device("mps:0"))
yield
torch.set_default_device(device)

Expand Down Expand Up @@ -1468,8 +1468,8 @@ def check_meta(tensor):

if torch.cuda.is_available():
device = "cuda:0"
elif torch.backends.mps.is_available():
device = "mps:0"
# elif torch.backends.mps.is_available():
# device = "mps:0"
else:
pytest.skip("no device to test")
device_state_dict = TensorDict.load(tmpdir, device=device)
Expand Down Expand Up @@ -1717,8 +1717,8 @@ def test_no_batch_size(self):
def test_non_blocking(self):
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
pytest.skip("No device found")
for _ in range(10):
Expand Down Expand Up @@ -1792,9 +1792,9 @@ def test_non_blocking_single_sync(self, _path_td_sync):
TensorDict(td_dict, device="cpu")
assert _SYNC_COUNTER == 0

if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
# if torch.backends.mps.is_available():
# device = "mps"
if torch.cuda.is_available():
device = "cuda"
else:
device = None
Expand Down Expand Up @@ -9857,7 +9857,8 @@ def check_weakref_count(weakref_list, expected):
assert count == expected, {id(ref()) for ref in weakref_list}

@pytest.mark.skipif(
not torch.cuda.is_available() and not torch.backends.mps.is_available(),
not torch.cuda.is_available(),
# and not torch.backends.mps.is_available(),
reason="a device is required.",
)
def test_cached_data_lock_device(self):
Expand Down

0 comments on commit ab694e6

Please sign in to comment.