Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 19, 2024
2 parents 992e37b + 70ebf56 commit c4d8b45
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

@pytest.fixture(scope="module", autouse=True)
def set_default_device():
cur_device = torch.get_default_device()
cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
yield
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
"low": torch.as_tensor(action_spec.space.low, device=device),
"high": torch.as_tensor(action_spec.space.high, device=device),
"tanh_loc": False,
"safe_tanh": not cfg.loss.compile,
"safe_tanh": not cfg.compile.compile,
},
default_interaction_type=ExplorationType.RANDOM,
)
Expand Down
4 changes: 3 additions & 1 deletion torchrl/envs/libs/vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,9 @@ def _build_env(
num_envs=num_envs,
device=self.device
if self.device is not None
else torch.get_default_device(),
else getattr(
torch, "get_default_device", lambda: torch.device("cpu")
)(),
continuous_actions=continuous_actions,
max_steps=max_steps,
seed=seed,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def _get_default_device(net):
for p in net.parameters():
return p.device
else:
return torch.get_default_device()
return getattr(torch, "get_default_device", lambda: torch.device("cpu"))()


def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(
):
super().__init__()
if device is None:
device = torch.get_default_device()
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
# this is saved for tracking only and should not be used to cast anything else than buffers during
# init.
self._device = device
Expand Down

0 comments on commit c4d8b45

Please sign in to comment.