[BugFix] Fix failing tests
ghstack-source-id: d17d760c200aeffc0f648aa770b36b23a23cc604
Pull Request resolved: #2582
vmoens committed Nov 19, 2024
1 parent 408cf7d commit 9969144
Showing 8 changed files with 127 additions and 65 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docs.yml
@@ -119,6 +119,8 @@ jobs:
       REF_TYPE=${{ github.ref_type }}
       REF_NAME=${{ github.ref_name }}
+      apt-get update
+      apt-get install rsync -y
       if [[ "${REF_TYPE}" == branch ]]; then
         if [[ "${REF_NAME}" == main ]]; then
4 changes: 2 additions & 2 deletions test/_utils_internal.py
@@ -167,11 +167,11 @@ def get_available_devices():
 def get_default_devices():
     num_cuda = torch.cuda.device_count()
     if num_cuda == 0:
+        if torch.mps.is_available():
+            return [torch.device("mps:0")]
         return [torch.device("cpu")]
     elif num_cuda == 1:
         return [torch.device("cuda:0")]
-    elif torch.mps.is_available():
-        return [torch.device("mps:0")]
     else:
         # then run on all devices
         return get_available_devices()
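Note on the fix above: the MPS branch previously sat behind num_cuda >= 2 and was effectively unreachable; moving it under num_cuda == 0 makes MPS the fallback for CUDA-less hosts. A minimal sketch of the resulting priority, omitting the multi-GPU branch that returns every available device:

    import torch

    def pick_default_device() -> torch.device:
        # CUDA first; MPS only on CUDA-less machines; CPU as last resort.
        if torch.cuda.device_count() >= 1:
            return torch.device("cuda:0")
        if torch.mps.is_available():
            return torch.device("mps:0")
        return torch.device("cpu")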
2 changes: 1 addition & 1 deletion test/test_actors.py
@@ -47,7 +47,7 @@
         ("data", "sample_log_prob"),
     ],
 )
-def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=3):
+def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=1):
     env = NestedCountingEnv(nested_dim=nested_dim)
     action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1)
     policy_module = TensorDictModule(
31 changes: 14 additions & 17 deletions test/test_exploration.py
@@ -241,8 +241,8 @@ def test_ou(
         self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
     ):
         torch.manual_seed(seed)
-        net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
-            device
+        net = nn.Sequential(
+            nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor()
         )
         module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
         action_spec = Bounded(-torch.ones(d_act), torch.ones(d_act), (d_act,))
@@ -252,13 +252,13 @@ def test_ou(
             in_keys=["loc", "scale"],
             distribution_class=TanhNormal,
             default_interaction_type=InteractionType.RANDOM,
-        ).to(device)
+        )
 
         if interface == "module":
-            ou = OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
+            ou = OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)
             exploratory_policy = TensorDictSequential(policy, ou)
         else:
-            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
+            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device)
             ou = exploratory_policy
 
         tensordict = TensorDict(
@@ -338,10 +338,10 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0
 
         if interface == "module":
             exploratory_policy = TensorDictSequential(
-                policy, OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
+                policy, OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)
             )
         else:
-            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
+            exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device)
         exploratory_policy(env.reset())
         collector = SyncDataCollector(
             create_env_fn=env,
@@ -456,10 +456,10 @@ def test_additivegaussian_sd(
             device=device,
         )
         if interface == "module":
-            exploratory_policy = AdditiveGaussianModule(action_spec).to(device)
+            exploratory_policy = AdditiveGaussianModule(action_spec, device=device)
         else:
-            net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
-                device
+            net = nn.Sequential(
+                nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor()
             )
             module = SafeModule(
                 net,
@@ -473,10 +473,10 @@ def test_additivegaussian_sd(
                 in_keys=["loc", "scale"],
                 distribution_class=TanhNormal,
                 default_interaction_type=InteractionType.RANDOM,
-            ).to(device)
+            )
             given_spec = action_spec if spec_origin == "spec" else None
-            exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(
-                device
+            exploratory_policy = AdditiveGaussianWrapper(
+                policy, spec=given_spec, device=device
             )
         if spec_origin is not None:
             sigma_init = (
@@ -727,10 +727,7 @@ def test_gsde(
 @pytest.mark.parametrize("std", [1, 2])
 @pytest.mark.parametrize("sigma_init", [None, 1.5, 3])
 @pytest.mark.parametrize("learn_sigma", [False, True])
-@pytest.mark.parametrize(
-    "device",
-    [torch.device("cuda:0") if torch.cuda.device_count() else torch.device("cpu")],
-)
+@pytest.mark.parametrize("device", get_default_devices())
 def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_sigma):
     torch.manual_seed(0)
     state = torch.randn(10000, *state_dim, device=device) * std + mean
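The common thread in these test changes: rather than building a module on the default device and moving it afterwards with .to(device), the device is passed at construction so parameters and buffers are allocated in place. A minimal usage sketch of the new keyword (CPU for portability; assumes the torchrl imports below):

    import torch
    from torchrl.data import Bounded
    from torchrl.modules import OrnsteinUhlenbeckProcessModule

    device = torch.device("cpu")
    action_spec = Bounded(-torch.ones(6), torch.ones(6), (6,), device=device)
    # Buffers (eps_init, eps_end, eps, ...) are created directly on device.
    ou = OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)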
3 changes: 3 additions & 0 deletions test/test_transforms.py
@@ -2076,7 +2076,10 @@ def test_transform_rb(self, rbclass):
         ):
             td = rb.sample(10)
 
+    @retry(AssertionError, tries=10, delay=0)
     def test_collector_match(self):
+        torch.manual_seed(0)
+
         # The counter in the collector should match the one from the transform
         t = TrajCounter()
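The retry decorator makes this occasionally flaky collector test tolerant of spurious failures: it re-runs the test on AssertionError, up to ten times with no delay between attempts. A hypothetical minimal version of such a helper (the real one lives in torchrl's test utilities; shown only to illustrate the tries/delay semantics):

    import functools
    import time

    def retry(exc, tries=3, delay=0.0):
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                for attempt in range(tries):
                    try:
                        return fn(*args, **kwargs)
                    except exc:
                        if attempt == tries - 1:
                            raise  # budget exhausted: surface the failure
                        time.sleep(delay)
            return wrapper
        return decorator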
9 changes: 9 additions & 0 deletions torchrl/modules/distributions/continuous.py
@@ -611,6 +611,15 @@ def __init__(
         event_shape = param.shape[-1:]
         super().__init__(batch_shape=batch_shape, event_shape=event_shape)
 
+    def expand(self, batch_shape: torch.Size, _instance=None):
+        if self.batch_shape != tuple(batch_shape):
+            return type(self)(
+                self.param.expand((*batch_shape, *self.event_shape)),
+                atol=self.atol,
+                rtol=self.rtol,
+            )
+        return self
+
     def update(self, param):
         self.param = param
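The new expand follows the standard torch.distributions.Distribution.expand contract: return a distribution with the requested batch_shape (here by expanding the stored parameter), or self when the shape already matches. A usage sketch, assuming Delta is importable from torchrl.modules and that the parameter's last dimension is the event dimension:

    import torch
    from torchrl.modules import Delta

    d = Delta(torch.zeros(3))             # batch_shape=(), event_shape=(3,)
    d5 = d.expand(torch.Size((5,)))       # parameter expanded to shape (5, 3)
    assert d5.batch_shape == torch.Size((5,))
    assert d.expand(torch.Size(())) is d  # no-op when shapes already match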
88 changes: 57 additions & 31 deletions torchrl/modules/tensordict_module/exploration.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import warnings
 from typing import Optional, Union
@@ -232,6 +234,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper):
             is set to False but the spec is passed, the projection will still
             happen.
             Default is True.
+        device (torch.device, optional): the device where the buffers have to be stored.
 
     .. note::
         Once an environment has been wrapped in :class:`AdditiveGaussianWrapper`, it is
@@ -255,22 +258,30 @@ def __init__(
         action_key: Optional[NestedKey] = "action",
         spec: Optional[TensorSpec] = None,
         safe: Optional[bool] = True,
+        device: torch.device | None = None,
     ):
         warnings.warn(
             "AdditiveGaussianWrapper is deprecated and will be removed "
             "in v0.7. Please use torchrl.modules.AdditiveGaussianModule "
             "instead.",
             category=DeprecationWarning,
         )
+        if device is None and hasattr(policy, "parameters"):
+            for p in policy.parameters():
+                device = p.device
+                break
+
         super().__init__(policy)
         if sigma_end > sigma_init:
             raise RuntimeError("sigma should decrease over time or be constant")
-        self.register_buffer("sigma_init", torch.tensor([sigma_init]))
-        self.register_buffer("sigma_end", torch.tensor([sigma_end]))
+        self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device))
+        self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device))
         self.annealing_num_steps = annealing_num_steps
-        self.register_buffer("mean", torch.tensor([mean]))
-        self.register_buffer("std", torch.tensor([std]))
-        self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32))
+        self.register_buffer("mean", torch.tensor([mean], device=device))
+        self.register_buffer("std", torch.tensor([std], device=device))
+        self.register_buffer(
+            "sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device)
+        )
         self.action_key = action_key
         self.out_keys = list(self.td_module.out_keys)
         if action_key not in self.out_keys:
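When no device is passed, the wrapper now infers one from the wrapped policy's parameters, so the noise buffers end up co-located with the policy weights by default. The fallback in isolation:

    import torch
    from torch import nn

    policy = nn.Linear(4, 2)  # stand-in for the wrapped policy

    device = None
    if device is None and hasattr(policy, "parameters"):
        for p in policy.parameters():
            device = p.device
            break  # the first parameter's device is taken as representative
    print(device)  # cpu, or cuda:N if the policy lives on a GPU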
@@ -312,18 +323,17 @@ def step(self, frames: int = 1) -> None:
         for _ in range(frames):
             self.sigma.data.copy_(
                 torch.maximum(
-                    self.sigma_end(
-                        self.sigma
-                        - (self.sigma_init - self.sigma_end) / self.annealing_num_steps
-                    ),
-                )
+                    self.sigma_end,
+                    self.sigma
+                    - (self.sigma_init - self.sigma_end) / self.annealing_num_steps,
+                ),
             )
 
     def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
         sigma = self.sigma
         noise = torch.normal(
-            mean=torch.ones(action.shape) * self.mean,
-            std=torch.ones(action.shape) * self.std,
+            mean=self.mean.expand(action.shape),
+            std=self.std.expand(action.shape),
         ).to(action.device)
         action = action + noise * sigma
         spec = self.spec
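The step fix above repairs a typo that called sigma_end as if it were a function. The intended update is linear annealing clamped from below, sigma <- max(sigma_end, sigma - (sigma_init - sigma_end) / annealing_num_steps), one step per frame. As a plain-tensor sketch (all arguments are tensors, matching the buffers above):

    import torch

    def anneal_sigma(sigma, sigma_init, sigma_end, annealing_num_steps, frames=1):
        # Linear decay toward sigma_end, never overshooting below it.
        for _ in range(frames):
            decrement = (sigma_init - sigma_end) / annealing_num_steps
            sigma = torch.maximum(sigma_end, sigma - decrement)
        return sigma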
@@ -372,6 +382,7 @@ class AdditiveGaussianModule(TensorDictModuleBase):
         safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space
             given the :obj:`TensorSpec.project` heuristic.
             default: True
+        device (torch.device, optional): the device where the buffers have to be stored.
 
     .. note::
         It is
@@ -394,6 +405,7 @@ def __init__(
         *,
         action_key: Optional[NestedKey] = "action",
         safe: bool = True,
+        device: torch.device | None = None,
     ):
         if not isinstance(sigma_init, float):
             warnings.warn("eps_init should be a float.")
@@ -405,12 +417,14 @@ def __init__(
 
         super().__init__()
 
-        self.register_buffer("sigma_init", torch.tensor([sigma_init]))
-        self.register_buffer("sigma_end", torch.tensor([sigma_end]))
+        self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device))
+        self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device))
         self.annealing_num_steps = annealing_num_steps
-        self.register_buffer("mean", torch.tensor([mean]))
-        self.register_buffer("std", torch.tensor([std]))
-        self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32))
+        self.register_buffer("mean", torch.tensor([mean], device=device))
+        self.register_buffer("std", torch.tensor([std], device=device))
+        self.register_buffer(
+            "sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device)
+        )
 
         if spec is not None:
             if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
@@ -449,8 +463,8 @@ def step(self, frames: int = 1) -> None:
     def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
         sigma = self.sigma
         noise = torch.normal(
-            mean=torch.ones(action.shape) * self.mean,
-            std=torch.ones(action.shape) * self.std,
+            mean=self.mean.expand(action.shape),
+            std=self.std.expand(action.shape),
         ).to(action.device)
         action = action + noise * sigma
         spec = self.spec[self.action_key]
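In _add_noise, mean.expand(action.shape) and std.expand(action.shape) broadcast the [1]-shaped buffers to the action's shape as zero-copy views, instead of materializing torch.ones(action.shape) on the default device just to scale it. Sketch with free-standing tensors:

    import torch

    mean = torch.tensor([0.0])
    std = torch.tensor([1.0])
    action = torch.zeros(32, 6)
    # expand returns views; torch.normal samples elementwise from them.
    noise = torch.normal(mean=mean.expand(action.shape), std=std.expand(action.shape))
    assert noise.shape == action.shape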
@@ -530,6 +544,7 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
         safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space
             given the :obj:`TensorSpec.project` heuristic.
             default: True
+        device (torch.device, optional): the device where the buffers have to be stored.
 
     Examples:
         >>> import torch
@@ -573,13 +588,18 @@ def __init__(
         spec: TensorSpec = None,
         safe: bool = True,
         key: Optional[NestedKey] = None,
+        device: torch.device | None = None,
     ):
         warnings.warn(
             "OrnsteinUhlenbeckProcessWrapper is deprecated and will be removed "
             "in v0.7. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule "
             "instead.",
             category=DeprecationWarning,
         )
+        if device is None and hasattr(policy, "parameters"):
+            for p in policy.parameters():
+                device = p.device
+                break
         if key is not None:
             action_key = key
             warnings.warn(
@@ -596,15 +616,17 @@ def __init__(
             n_steps_annealing=n_steps_annealing,
             key=action_key,
         )
-        self.register_buffer("eps_init", torch.tensor([eps_init]))
-        self.register_buffer("eps_end", torch.tensor([eps_end]))
+        self.register_buffer("eps_init", torch.tensor([eps_init], device=device))
+        self.register_buffer("eps_end", torch.tensor([eps_end], device=device))
         if self.eps_end > self.eps_init:
             raise ValueError(
                 "eps should decrease over time or be constant, "
                 f"got eps_init={eps_init} and eps_end={eps_end}"
             )
         self.annealing_num_steps = annealing_num_steps
-        self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))
+        self.register_buffer(
+            "eps", torch.tensor([eps_init], dtype=torch.float32, device=device)
+        )
         self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys
         self.is_init_key = is_init_key
         noise_key = self.ou.noise_key
@@ -746,6 +768,7 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):
         is set to False but the spec is passed, the projection will still
         happen.
         Default is True.
+        device (torch.device, optional): the device where the buffers have to be stored.
 
     Examples:
         >>> import torch
@@ -782,13 +805,14 @@ def __init__(
         mu: float = 0.0,
         sigma: float = 0.2,
         dt: float = 1e-2,
-        x0: Optional[Union[torch.Tensor, np.ndarray]] = None,
-        sigma_min: Optional[float] = None,
+        x0: torch.Tensor | np.ndarray | None = None,
+        sigma_min: float | None = None,
         n_steps_annealing: int = 1000,
         *,
-        action_key: Optional[NestedKey] = "action",
-        is_init_key: Optional[NestedKey] = "is_init",
+        action_key: NestedKey = "action",
+        is_init_key: NestedKey = "is_init",
         safe: bool = True,
+        device: torch.device | None = None,
     ):
         super().__init__()
 
@@ -803,15 +827,17 @@ def __init__(
             key=action_key,
         )
 
-        self.register_buffer("eps_init", torch.tensor([eps_init]))
-        self.register_buffer("eps_end", torch.tensor([eps_end]))
+        self.register_buffer("eps_init", torch.tensor([eps_init], device=device))
+        self.register_buffer("eps_end", torch.tensor([eps_end], device=device))
         if self.eps_end > self.eps_init:
             raise ValueError(
                 "eps should decrease over time or be constant, "
                 f"got eps_init={eps_init} and eps_end={eps_end}"
             )
         self.annealing_num_steps = annealing_num_steps
-        self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))
+        self.register_buffer(
+            "eps", torch.tensor([eps_init], dtype=torch.float32, device=device)
+        )
 
         self.in_keys = [self.ou.key]
         self.out_keys = [self.ou.key] + self.ou.out_keys
@@ -946,8 +972,8 @@ def _make_noise_pair(
         noise = tensordict.get(self.noise_key).clone()
         steps = tensordict.get(self.steps_key).clone()
         if is_init is not None:
-            noise = torch.masked_fill(noise, is_init, 0)
-            steps = torch.masked_fill(steps, is_init, 0)
+            noise = torch.masked_fill(noise, expand_right(is_init, noise.shape), 0)
+            steps = torch.masked_fill(steps, expand_right(is_init, steps.shape), 0)
         return noise, steps
 
     def add_sample(
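The masked_fill fix: broadcasting aligns shapes from the right, so a [B]-shaped is_init mask cannot be applied directly to [B, A]-shaped noise; expand_right pads the mask with expanded trailing dimensions first. A minimal sketch, assuming expand_right from tensordict.utils:

    import torch
    from tensordict.utils import expand_right

    is_init = torch.tensor([True, False, True])  # per-env reset flags, shape [3]
    noise = torch.randn(3, 6)                    # per-env noise, shape [3, 6]
    # torch.masked_fill(noise, is_init, 0) would raise: [3] does not
    # broadcast against [3, 6]. Expand the mask to the noise shape first:
    mask = expand_right(is_init, noise.shape)    # shape [3, 6]
    noise = torch.masked_fill(noise, mask, 0)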