From 016d5dd19b38d85a76196feeaa8e81159a6baa51 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 20:06:42 -0800 Subject: [PATCH] [Feature] TD3 compatibility with compile ghstack-source-id: fb94307557f2b8604403b48211e3da6fb2139e28 Pull Request resolved: https://github.com/pytorch/rl/pull/2658 --- sota-implementations/cql/cql_online.py | 6 +- .../cql/discrete_cql_online.py | 6 +- sota-implementations/crossq/crossq.py | 6 +- sota-implementations/ddpg/ddpg.py | 6 +- .../discrete_sac/discrete_sac.py | 6 +- sota-implementations/iql/discrete_iql.py | 6 +- sota-implementations/iql/iql_online.py | 6 +- sota-implementations/sac/config.yaml | 2 +- sota-implementations/sac/sac.py | 6 +- sota-implementations/td3/config.yaml | 7 +- sota-implementations/td3/td3.py | 200 ++++--- sota-implementations/td3/utils.py | 130 ++--- test/test_specs.py | 536 ++++++++---------- torchrl/data/replay_buffers/replay_buffers.py | 8 + torchrl/data/tensor_specs.py | 67 ++- .../modules/tensordict_module/exploration.py | 5 +- 16 files changed, 501 insertions(+), 502 deletions(-) diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index b61556874c3..03bdf6a493f 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -159,11 +159,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb frames_per_batch = cfg.collector.frames_per_batch evaluation_interval = cfg.logger.log_interval diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index e6a710f1f4b..35238c5c6ab 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -140,11 +140,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index 5f6d762d644..07de3e26175 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -179,11 +179,7 @@ def update(sampled_tensordict: TensorDict, update_actor: bool): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index 9d06dc2ff75..6e2a749c3f1 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -145,11 +145,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb frames_per_batch = cfg.collector.frames_per_batch eval_iter = cfg.logger.eval_iter diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index a5dad120a60..b7910c4e578 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -144,11 +144,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 805cfc6e23d..e56661acf0c 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -148,11 +148,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 28b35099286..7ec2a30dfd9 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -145,11 +145,7 @@ def update(sampled_tensordict): collected_frames = 0 init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml index a1ecb90aeba..d6cb09382aa 100644 --- a/sota-implementations/sac/config.yaml +++ b/sota-implementations/sac/config.yaml @@ -13,7 +13,7 @@ collector: frames_per_batch: 1000 init_env_steps: 1000 device: - env_per_collector: 1 + env_per_collector: 8 reset_at_each_iter: False # replay buffer diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index b97fed3091c..a1ec631fe39 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -143,11 +143,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml index 5bdf22ea6fa..31fa52b72f3 100644 --- a/sota-implementations/td3/config.yaml +++ b/sota-implementations/td3/config.yaml @@ -14,7 +14,7 @@ collector: frames_per_batch: 1000 reset_at_each_iter: False device: - env_per_collector: 1 + env_per_collector: 8 num_workers: 1 # replay buffer @@ -52,3 +52,8 @@ logger: mode: online eval_iter: 25000 video: False + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 70333f56cd9..bcbe6b879da 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -12,14 +12,16 @@ """ from __future__ import annotations -import time +import warnings import hydra import numpy as np import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import compile_with_warmup, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -36,6 +38,9 @@ ) +torch.set_float32_matmul_precision("high") + + @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 device = cfg.network.device @@ -44,7 +49,8 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device("cuda:0") else: device = torch.device("cpu") - device = torch.device(device) + else: + device = torch.device(device) # Create logger exp_name = generate_exp_name("TD3", cfg.logger.exp_name) @@ -67,7 +73,7 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create environments - train_env, eval_env = make_environment(cfg, logger=logger) + train_env, eval_env = make_environment(cfg, logger=logger, device=device) # Create agent model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device) @@ -75,8 +81,23 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create TD3 loss loss_module, target_net_updater = make_loss_module(cfg, model) + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create off-policy collector - collector = make_collector(cfg, train_env, exploration_policy) + collector = make_collector( + cfg, + train_env, + exploration_policy, + compile_mode=compile_mode, + device=device, + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -84,94 +105,111 @@ def main(cfg: "DictConfig"): # noqa: F821 prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, scratch_dir=cfg.replay_buffer.scratch_dir, - device="cpu", + device=device, + compile=bool(compile_mode), ) # Create optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) + prb = cfg.replay_buffer.prb + + def update(sampled_tensordict, update_actor, prb=prb): + + # Compute loss + q_loss, *_ = loss_module.value_loss(sampled_tensordict) + + # Update critic + q_loss.backward() + optimizer_critic.step() + optimizer_critic.zero_grad(set_to_none=True) + + # Update actor + if update_actor: + actor_loss, *_ = loss_module.actor_loss(sampled_tensordict) + + actor_loss.backward() + optimizer_actor.step() + optimizer_actor.zero_grad(set_to_none=True) + + # Update target params + target_net_updater.step() + else: + actor_loss = q_loss.new_zeros(()) + + return q_loss.detach(), actor_loss.detach() + + if cfg.compile.compile: + update = compile_with_warmup(update, mode=compile_mode, warmup=1) + + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) delayed_updates = cfg.optim.policy_update_delay - prb = cfg.replay_buffer.prb eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch update_counter = 0 - sampling_start = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start - exploration_policy[1].step(tensordict.numel()) + collector_iter = iter(collector) + total_iter = len(collector) + + for _ in range(total_iter): + timeit.printevery(num_prints=1000, total_count=total_iter, erase=True) + + with timeit("collect"): + tensordict = next(collector_iter) # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) - - tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() - # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + pbar.update(current_frames) + + with timeit("rb - extend"): + # Add to replay buffer + tensordict = tensordict.reshape(-1) + replay_buffer.extend(tensordict) + collected_frames += current_frames - # Optimization steps - training_start = time.time() - if collected_frames >= init_random_frames: - ( - actor_losses, - q_losses, - ) = ([], []) - for _ in range(num_updates): - - # Update actor every delayed_updates - update_counter += 1 - update_actor = update_counter % delayed_updates == 0 - - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Compute loss - q_loss, *_ = loss_module.value_loss(sampled_tensordict) - - # Update critic - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - q_losses.append(q_loss.item()) - - # Update actor - if update_actor: - actor_loss, *_ = loss_module.actor_loss(sampled_tensordict) - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - actor_losses.append(actor_loss.item()) - - # Update target params - target_net_updater.step() - - # Update priority - if prb: - replay_buffer.update_priority(sampled_tensordict) - - training_time = time.time() - training_start + with timeit("train"): + # Optimization steps + if collected_frames >= init_random_frames: + ( + actor_losses, + q_losses, + ) = ([], []) + for _ in range(num_updates): + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample() + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + q_loss, actor_loss = update(sampled_tensordict, update_actor) + + # Update priority + if prb: + with timeit("rb - priority"): + replay_buffer.update_priority(sampled_tensordict) + + q_losses.append(q_loss.clone()) + if update_actor: + actor_losses.append(actor_loss.clone()) + episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -183,22 +221,21 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log = {} if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][episode_end] - metrics_to_log["train/reward"] = episode_rewards.mean().item() - metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + metrics_to_log["train/reward"] = episode_rewards.mean() + metrics_to_log["train/episode_length"] = episode_length.sum() / len( episode_length ) if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) + metrics_to_log["train/q_loss"] = torch.stack(q_losses).mean() if update_actor: - metrics_to_log["train/a_loss"] = np.mean(actor_losses) - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + metrics_to_log["train/a_loss"] = torch.stack(actor_losses).mean() # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, exploration_policy, @@ -206,22 +243,17 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time if logger is not None: + metrics_to_log.update(timeit.todict(prefix="time")) log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index df81a522b3c..9562da65450 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -9,12 +9,12 @@ from contextlib import nullcontext import torch -from tensordict.nn import TensorDictSequential +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn, optim from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage from torchrl.envs import ( CatTensors, Compose, @@ -29,14 +29,7 @@ ) from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import ( - AdditiveGaussianModule, - MLP, - SafeModule, - SafeSequential, - TanhModule, - ValueOperator, -) +from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator from torchrl.objectives import SoftUpdate from torchrl.objectives.td3 import TD3Loss @@ -82,13 +75,14 @@ def apply_env_transforms(env, max_episode_steps): return transformed_env -def make_environment(cfg, logger=None): +def make_environment(cfg, logger, device): """Make environments for training and evaluation.""" partial = functools.partial(env_maker, cfg=cfg) parallel_env = ParallelEnv( cfg.collector.env_per_collector, EnvCreator(partial), serial_for_single=True, + device=device, ) parallel_env.set_seed(cfg.env.seed) @@ -102,9 +96,10 @@ def make_environment(cfg, logger=None): ) eval_env = TransformedEnv( ParallelEnv( - cfg.collector.env_per_collector, + 1, EnvCreator(partial), serial_for_single=True, + device=device, ), trsf_clone, ) @@ -116,14 +111,11 @@ def make_environment(cfg, logger=None): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector(cfg, train_env, actor_model_explore, compile_mode, device): """Make collector.""" - device = cfg.collector.device - if device in ("", None): - if torch.cuda.is_available(): - device = torch.device("cuda:0") - else: - device = torch.device("cpu") + collector_device = cfg.collector.device + if collector_device in ("", None): + collector_device = device collector = SyncDataCollector( train_env, actor_model_explore, @@ -131,49 +123,60 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, - device=device, + device=collector_device, + compile_policy={"mode": compile_mode} if compile_mode else False, + cudagraph_policy=cfg.compile.cudagraphs, ) collector.set_seed(cfg.env.seed) return collector def make_replay_buffer( - batch_size, - prb=False, - buffer_size=1000000, - scratch_dir=None, - device="cpu", - prefetch=3, + batch_size: int, + prb: bool = False, + buffer_size: int = 1000000, + scratch_dir: str | None = None, + device: torch.device = "cpu", + prefetch: int = 3, + compile: bool = False, ): - with ( - tempfile.TemporaryDirectory() - if scratch_dir is None - else nullcontext(scratch_dir) - ) as scratch_dir: + if compile: + prefetch = 0 + if scratch_dir in ("", None): + ctx = nullcontext(None) + elif scratch_dir == "temp": + ctx = tempfile.TemporaryDirectory() + else: + ctx = nullcontext(scratch_dir) + with ctx as scratch_dir: + storage_cls = ( + functools.partial(LazyTensorStorage, device=device, compilable=compile) + if not scratch_dir + else functools.partial( + LazyMemmapStorage, device="cpu", scratch_dir=scratch_dir + ) + ) + if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( alpha=0.7, beta=0.5, pin_memory=False, prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=scratch_dir, - device=device, - ), + storage=storage_cls(buffer_size), batch_size=batch_size, + compilable=compile, ) else: replay_buffer = TensorDictReplayBuffer( pin_memory=False, prefetch=prefetch, - storage=LazyMemmapStorage( - buffer_size, - scratch_dir=scratch_dir, - device=device, - ), + storage=storage_cls(buffer_size), batch_size=batch_size, + compilable=compile, ) + if scratch_dir: + replay_buffer.append_transform(lambda td: td.to(device)) return replay_buffer @@ -186,26 +189,21 @@ def make_td3_agent(cfg, train_env, eval_env, device): """Make TD3 agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] - actor_net_kwargs = { - "num_cells": cfg.network.hidden_sizes, - "out_features": action_spec.shape[-1], - "activation_class": get_activation(cfg), - } - - actor_net = MLP(**actor_net_kwargs) + action_spec = train_env.action_spec_unbatched.to(device) + actor_net = MLP( + num_cells=cfg.network.hidden_sizes, + out_features=action_spec.shape[-1], + activation_class=get_activation(cfg), + device=device, + ) in_keys_actor = in_keys - actor_module = SafeModule( + actor_module = TensorDictModule( actor_net, in_keys=in_keys_actor, - out_keys=[ - "param", - ], + out_keys=["param"], ) - actor = SafeSequential( + actor = TensorDictSequential( actor_module, TanhModule( in_keys=["param"], @@ -215,14 +213,11 @@ def make_td3_agent(cfg, train_env, eval_env, device): ) # Define Critic Network - qvalue_net_kwargs = { - "num_cells": cfg.network.hidden_sizes, - "out_features": 1, - "activation_class": get_activation(cfg), - } - qvalue_net = MLP( - **qvalue_net_kwargs, + num_cells=cfg.network.hidden_sizes, + out_features=1, + activation_class=get_activation(cfg), + device=device, ) qvalue = ValueOperator( @@ -230,20 +225,17 @@ def make_td3_agent(cfg, train_env, eval_env, device): module=qvalue_net, ) - model = nn.ModuleList([actor, qvalue]).to(device) + model = nn.ModuleList([actor, qvalue]) # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = eval_env.reset() + td = eval_env.fake_tensordict() td = td.to(device) for net in model: net(td) - del td - eval_env.close() - # Exploration wrappers: actor_model_explore = TensorDictSequential( - model[0], + actor, AdditiveGaussianModule( sigma_init=1, sigma_end=1, diff --git a/test/test_specs.py b/test/test_specs.py index 5334281f0ee..a75ff0352c7 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -59,316 +59,278 @@ ) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -def test_bounded(dtype): - torch.manual_seed(0) - np.random.seed(0) - for _ in range(100): - bounds = torch.randn(2).sort()[0] - ts = Bounded(bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype) - _dtype = dtype - if dtype is None: - _dtype = torch.get_default_dtype() - - r = ts.rand() - assert ts.is_in(r) - assert r.dtype is _dtype - ts.is_in(ts.encode(bounds.mean())) - ts.is_in(ts.encode(bounds.mean().item())) - assert (ts.encode(ts.to_numpy(r)) == r).all() - - -@pytest.mark.parametrize("cls", [OneHot, Categorical]) -def test_discrete(cls): - torch.manual_seed(0) - np.random.seed(0) +class TestRanges: + @pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.float64, None] + ) + def test_bounded(self, dtype): + torch.manual_seed(0) + np.random.seed(0) + for _ in range(100): + bounds = torch.randn(2).sort()[0] + ts = Bounded( + bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype + ) + _dtype = dtype + if dtype is None: + _dtype = torch.get_default_dtype() - ts = cls(10) - for _ in range(100): - r = ts.rand() - ts.to_numpy(r) - ts.encode(torch.tensor([5])) - ts.encode(torch.tensor(5).numpy()) - ts.encode(9) - with pytest.raises(AssertionError), set_global_var( - torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True - ): - ts.encode(torch.tensor([11])) # out of bounds - assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE - assert ts.is_in(r) - assert (ts.encode(ts.to_numpy(r)) == r).all() + r = ts.rand() + assert (ts._project(r) == r).all() + assert ts.is_in(r) + assert r.dtype is _dtype + ts.is_in(ts.encode(bounds.mean())) + ts.is_in(ts.encode(bounds.mean().item())) + assert (ts.encode(ts.to_numpy(r)) == r).all() + @pytest.mark.parametrize("cls", [OneHot, Categorical]) + def test_discrete(self, cls): + torch.manual_seed(0) + np.random.seed(0) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -def test_unbounded(dtype): - torch.manual_seed(0) - np.random.seed(0) - ts = Unbounded(dtype=dtype) - - if dtype is None: - dtype = torch.get_default_dtype() - for _ in range(100): - r = ts.rand() - ts.to_numpy(r) - assert ts.is_in(r) - assert r.dtype is dtype - assert (ts.encode(ts.to_numpy(r)) == r).all() + ts = cls(10) + for _ in range(100): + r = ts.rand() + assert (ts._project(r) == r).all() + ts.to_numpy(r) + ts.encode(torch.tensor([5])) + ts.encode(torch.tensor(5).numpy()) + ts.encode(9) + with pytest.raises(AssertionError), set_global_var( + torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True + ): + ts.encode(torch.tensor([11])) # out of bounds + assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE + assert ts.is_in(r) + assert (ts.encode(ts.to_numpy(r)) == r).all() + @pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.float64, None] + ) + def test_unbounded(self, dtype): + torch.manual_seed(0) + np.random.seed(0) + ts = Unbounded(dtype=dtype) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -@pytest.mark.parametrize("shape", [[], torch.Size([3])]) -def test_ndbounded(dtype, shape): - torch.manual_seed(0) - np.random.seed(0) - - for _ in range(100): - lb = torch.rand(10) - 1 - ub = torch.rand(10) + 1 - ts = Bounded(lb, ub, dtype=dtype) - _dtype = dtype if dtype is None: - _dtype = torch.get_default_dtype() - - r = ts.rand(shape) - assert r.dtype is _dtype - assert r.shape == torch.Size([*shape, 10]) - assert (r >= lb.to(dtype)).all() and ( - r <= ub.to(dtype) - ).all(), f"{r[r <= lb] - lb.expand_as(r)[r <= lb]} -- {r[r >= ub] - ub.expand_as(r)[r >= ub]} " - ts.to_numpy(r) - assert ts.is_in(r) - ts.encode(lb + torch.rand(10) * (ub - lb)) - ts.encode((lb + torch.rand(10) * (ub - lb)).numpy()) - - if not shape: + dtype = torch.get_default_dtype() + for _ in range(100): + r = ts.rand() + assert (ts._project(r) == r).all() + ts.to_numpy(r) + assert ts.is_in(r) + assert r.dtype is dtype assert (ts.encode(ts.to_numpy(r)) == r).all() - else: - with pytest.raises(RuntimeError, match="Shape mismatch"): - ts.encode(ts.to_numpy(r)) - assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() - - with pytest.raises(AssertionError), set_global_var( - torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True - ): - ts.encode(torch.rand(10) + 3) # out of bounds - with pytest.raises(AssertionError), set_global_var( - torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True - ): - ts.to_numpy(torch.rand(10) + 3) # out of bounds - assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE + @pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.float64, None] + ) + @pytest.mark.parametrize("shape", [[], torch.Size([3])]) + def test_ndbounded(self, dtype, shape): + torch.manual_seed(0) + np.random.seed(0) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -@pytest.mark.parametrize("n", range(3, 10)) -@pytest.mark.parametrize( - "shape", - [ - [], - torch.Size( - [ - 3, - ] - ), - ], -) -def test_ndunbounded(dtype, n, shape): - torch.manual_seed(0) - np.random.seed(0) + for _ in range(100): + lb = torch.rand(10) - 1 + ub = torch.rand(10) + 1 + ts = Bounded(lb, ub, dtype=dtype) + _dtype = dtype + if dtype is None: + _dtype = torch.get_default_dtype() + + r = ts.rand(shape) + assert (ts._project(r) == r).all() + assert r.dtype is _dtype + assert r.shape == torch.Size([*shape, 10]) + assert (r >= lb.to(dtype)).all() and ( + r <= ub.to(dtype) + ).all(), f"{r[r <= lb] - lb.expand_as(r)[r <= lb]} -- {r[r >= ub] - ub.expand_as(r)[r >= ub]} " + ts.to_numpy(r) + assert ts.is_in(r) + ts.encode(lb + torch.rand(10) * (ub - lb)) + ts.encode((lb + torch.rand(10) * (ub - lb)).numpy()) + + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + + with pytest.raises(AssertionError), set_global_var( + torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True + ): + ts.encode(torch.rand(10) + 3) # out of bounds + with pytest.raises(AssertionError), set_global_var( + torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True + ): + ts.to_numpy(torch.rand(10) + 3) # out of bounds + assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE - ts = Unbounded( - shape=[ - n, - ], - dtype=dtype, + @pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.float64, None] ) + @pytest.mark.parametrize("n", range(3, 10)) + @pytest.mark.parametrize("shape", [(), torch.Size([3])]) + def test_ndunbounded(self, dtype, n, shape): + torch.manual_seed(0) + np.random.seed(0) - if dtype is None: - dtype = torch.get_default_dtype() + ts = Unbounded(shape=[n], dtype=dtype) - for _ in range(100): - r = ts.rand(shape) - assert r.shape == torch.Size( - [ - *shape, - n, - ] - ) - ts.to_numpy(r) - assert ts.is_in(r) - assert r.dtype is dtype - if not shape: - assert (ts.encode(ts.to_numpy(r)) == r).all() - else: - with pytest.raises(RuntimeError, match="Shape mismatch"): - ts.encode(ts.to_numpy(r)) - assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + if dtype is None: + dtype = torch.get_default_dtype() + for _ in range(100): + r = ts.rand(shape) + assert (ts._project(r) == r).all() + assert r.shape == torch.Size( + [ + *shape, + n, + ] + ) + ts.to_numpy(r) + assert ts.is_in(r) + assert r.dtype is dtype + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + + @pytest.mark.parametrize("n", range(3, 10)) + @pytest.mark.parametrize("shape", [(), torch.Size([3])]) + def test_binary(self, n, shape): + torch.manual_seed(0) + np.random.seed(0) -@pytest.mark.parametrize("n", range(3, 10)) -@pytest.mark.parametrize( - "shape", - [ - [], - torch.Size( - [ - 3, - ] - ), - ], -) -def test_binary(n, shape): - torch.manual_seed(0) - np.random.seed(0) - - ts = Binary(n) - for _ in range(100): - r = ts.rand(shape) - assert r.shape == torch.Size( - [ - *shape, - n, - ] - ) - assert ts.is_in(r) - assert ((r == 0) | (r == 1)).all() - if not shape: - assert (ts.encode(ts.to_numpy(r)) == r).all() - else: - with pytest.raises(RuntimeError, match="Shape mismatch"): - ts.encode(ts.to_numpy(r)) - assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + ts = Binary(n) + for _ in range(100): + r = ts.rand(shape) + assert (ts._project(r) == r).all() + assert r.shape == torch.Size([*shape, n]) + assert ts.is_in(r) + assert ((r == 0) | (r == 1)).all() + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + @pytest.mark.parametrize( + "ns", + [ + [5], + [5, 2, 3], + [4, 4, 1], + ], + ) + @pytest.mark.parametrize("shape", [(), torch.Size([3])]) + def test_mult_onehot(self, shape, ns): + torch.manual_seed(0) + np.random.seed(0) + ts = MultiOneHot(nvec=ns) + for _ in range(100): + r = ts.rand(shape) + assert (ts._project(r) == r).all() + assert r.shape == torch.Size([*shape, sum(ns)]) + assert ts.is_in(r) + assert ((r == 0) | (r == 1)).all() + rsplit = r.split(ns, dim=-1) + for _r, _n in zip(rsplit, ns): + assert (_r.sum(-1) == 1).all() + assert _r.shape[-1] == _n + categorical = ts.to_categorical(r) + assert not ts.is_in(categorical) + # assert (ts.encode(categorical) == r).all() + if not shape: + assert (ts.encode(categorical) == r).all() + else: + with pytest.raises(RuntimeError, match="is invalid for input of size"): + ts.encode(categorical) + assert (ts.expand(*shape, *ts.shape).encode(categorical) == r).all() -@pytest.mark.parametrize( - "ns", - [ + @pytest.mark.parametrize( + "ns", [ 5, + [5, 2, 3], + [4, 5, 1, 3], + [[1, 2], [3, 4]], + [[[2, 4], [3, 5]], [[4, 5], [2, 3]], [[2, 3], [3, 2]]], ], - [5, 2, 3], - [4, 4, 1], - ], -) -@pytest.mark.parametrize( - "shape", - [ - [], - torch.Size( - [ - 3, - ] - ), - ], -) -def test_mult_onehot(shape, ns): - torch.manual_seed(0) - np.random.seed(0) - ts = MultiOneHot(nvec=ns) - for _ in range(100): - r = ts.rand(shape) - assert r.shape == torch.Size( - [ - *shape, - sum(ns), - ] - ) - assert ts.is_in(r) - assert ((r == 0) | (r == 1)).all() - rsplit = r.split(ns, dim=-1) - for _r, _n in zip(rsplit, ns): - assert (_r.sum(-1) == 1).all() - assert _r.shape[-1] == _n - categorical = ts.to_categorical(r) - assert not ts.is_in(categorical) - # assert (ts.encode(categorical) == r).all() - if not shape: - assert (ts.encode(categorical) == r).all() - else: - with pytest.raises(RuntimeError, match="is invalid for input of size"): - ts.encode(categorical) - assert (ts.expand(*shape, *ts.shape).encode(categorical) == r).all() - - -@pytest.mark.parametrize( - "ns", - [ - 5, - [5, 2, 3], - [4, 5, 1, 3], - [[1, 2], [3, 4]], - [[[2, 4], [3, 5]], [[4, 5], [2, 3]], [[2, 3], [3, 2]]], - ], -) -@pytest.mark.parametrize("shape", [None, [], torch.Size([3]), torch.Size([4, 5])]) -@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.long]) -def test_multi_discrete(shape, ns, dtype): - torch.manual_seed(0) - np.random.seed(0) - ts = MultiCategorical(ns, dtype=dtype) - _real_shape = shape if shape is not None else [] - nvec_shape = torch.tensor(ns).size() - for _ in range(100): - r = ts.rand(shape) - - assert r.shape == torch.Size( - [ - *_real_shape, - *nvec_shape, - ] - ), (r.shape, ns, shape, _real_shape, nvec_shape) - assert ts.is_in(r), (r, r.shape, ns) - rand = torch.rand( - torch.Size( - [ - *_real_shape, - *nvec_shape, - ] - ) ) - projection = ts._project(rand) - - assert rand.shape == projection.shape - assert ts.is_in(projection) - if projection.ndim < 1: - projection.fill_(-1) - else: - projection[..., 0] = -1 - assert not ts.is_in(projection) - - -@pytest.mark.parametrize("n", [1, 4, 7, 99]) -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("shape", [None, [], [1], [1, 2]]) -def test_discrete_conversion(n, device, shape): - categorical = Categorical(n, device=device, shape=shape) - shape_one_hot = [n] if not shape else [*shape, n] - one_hot = OneHot(n, device=device, shape=shape_one_hot) - - assert categorical != one_hot - assert categorical.to_one_hot_spec() == one_hot - assert one_hot.to_categorical_spec() == categorical - - categorical_recon = one_hot.to_categorical(one_hot.rand(shape)) - assert categorical.is_in(categorical_recon), (categorical, categorical_recon) - one_hot_recon = categorical.to_one_hot(categorical.rand(shape)) - assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon) + @pytest.mark.parametrize("shape", [None, [], torch.Size([3]), torch.Size([4, 5])]) + @pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.long]) + def test_multi_discrete(self, shape, ns, dtype): + torch.manual_seed(0) + np.random.seed(0) + ts = MultiCategorical(ns, dtype=dtype) + _real_shape = shape if shape is not None else [] + nvec_shape = torch.tensor(ns).size() + for _ in range(100): + r = ts.rand(shape) + assert r.shape == torch.Size( + [ + *_real_shape, + *nvec_shape, + ] + ), (r.shape, ns, shape, _real_shape, nvec_shape) + assert ts.is_in(r), (r, r.shape, ns) + rand = torch.rand( + torch.Size( + [ + *_real_shape, + *nvec_shape, + ] + ) + ) + projection = ts._project(rand) -@pytest.mark.parametrize("ns", [[5], [5, 2, 3], [4, 5, 1, 3]]) -@pytest.mark.parametrize("shape", [torch.Size([3]), torch.Size([4, 5])]) -@pytest.mark.parametrize("device", get_default_devices()) -def test_multi_discrete_conversion(ns, shape, device): - categorical = MultiCategorical(ns, device=device) - one_hot = MultiOneHot(ns, device=device) - - assert categorical != one_hot - assert categorical.to_one_hot_spec() == one_hot - assert one_hot.to_categorical_spec() == categorical - - categorical_recon = one_hot.to_categorical(one_hot.rand(shape)) - assert categorical.is_in(categorical_recon), (categorical, categorical_recon) - one_hot_recon = categorical.to_one_hot(categorical.rand(shape)) - assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon) + assert rand.shape == projection.shape + assert ts.is_in(projection) + if projection.ndim < 1: + projection.fill_(-1) + else: + projection[..., 0] = -1 + assert not ts.is_in(projection) + + @pytest.mark.parametrize("n", [1, 4, 7, 99]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("shape", [None, [], [1], [1, 2]]) + def test_discrete_conversion(self, n, device, shape): + categorical = Categorical(n, device=device, shape=shape) + shape_one_hot = [n] if not shape else [*shape, n] + one_hot = OneHot(n, device=device, shape=shape_one_hot) + + assert categorical != one_hot + assert categorical.to_one_hot_spec() == one_hot + assert one_hot.to_categorical_spec() == categorical + + categorical_recon = one_hot.to_categorical(one_hot.rand(shape)) + assert categorical.is_in(categorical_recon), (categorical, categorical_recon) + one_hot_recon = categorical.to_one_hot(categorical.rand(shape)) + assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon) + + @pytest.mark.parametrize("ns", [[5], [5, 2, 3], [4, 5, 1, 3]]) + @pytest.mark.parametrize("shape", [torch.Size([3]), torch.Size([4, 5])]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_multi_discrete_conversion(self, ns, shape, device): + categorical = MultiCategorical(ns, device=device) + one_hot = MultiOneHot(ns, device=device) + + assert categorical != one_hot + assert categorical.to_one_hot_spec() == one_hot + assert one_hot.to_categorical_spec() == categorical + + categorical_recon = one_hot.to_categorical(one_hot.rand(shape)) + assert categorical.is_in(categorical_recon), (categorical, categorical_recon) + one_hot_recon = categorical.to_one_hot(categorical.rand(shape)) + assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon) @pytest.mark.parametrize("is_complete", [True, False]) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index fbb76b5a681..4ddf059d5b4 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -1094,6 +1094,9 @@ class TensorDictReplayBuffer(ReplayBuffer): .. warning:: As of now, the generator has no effect on the transforms. shared (bool, optional): whether the buffer will be shared using multiprocessing or not. Defaults to ``False``. + compilable (bool, optional): whether the writer is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. Examples: >>> import torch @@ -1437,6 +1440,9 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): .. warning:: As of now, the generator has no effect on the transforms. shared (bool, optional): whether the buffer will be shared using multiprocessing or not. Defaults to ``False``. + compilable (bool, optional): whether the writer is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. Examples: >>> import torch @@ -1512,6 +1518,7 @@ def __init__( dim_extend: int | None = None, generator: torch.Generator | None = None, shared: bool = False, + compilable: bool = False, ) -> None: if storage is None: storage = ListStorage(max_size=1_000) @@ -1530,6 +1537,7 @@ def __init__( dim_extend=dim_extend, generator=generator, shared=shared, + compilable=compilable, ) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ad29b63db04..5f724577ddd 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -44,6 +44,11 @@ from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for +try: + from torch.compiler import is_compiling +except ImportError: + from torch._dynamo import is_compiling + DEVICE_TYPING = Union[torch.device, str, int] INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List] @@ -381,11 +386,17 @@ class ContinuousBox(Box): # We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used. @property def low(self): - return self._low.to(self.device) + low = self._low + if self.device is not None and low.device != self.device: + low = low.to(self.device) + return low @property def high(self): - return self._high.to(self.device) + high = self._high + if self.device is not None and high.device != self.device: + high = high.to(self.device) + return high def unbind(self, dim: int = 0): return tuple( @@ -396,12 +407,12 @@ def unbind(self, dim: int = 0): @low.setter def low(self, value): self.device = value.device - self._low = value.cpu() + self._low = value @high.setter def high(self, value): self.device = value.device - self._high = value.cpu() + self._high = value def __post_init__(self): self.low = self.low.clone() @@ -871,7 +882,7 @@ def project( a torch.Tensor belonging to the TensorSpec box. """ - if not self.is_in(val): + if is_compiling() or not self.is_in(val): return self._project(val) return val @@ -1509,7 +1520,9 @@ def __init__( use_register: bool = False, mask: torch.Tensor | None = None, ): - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) self.use_register = use_register space = CategoricalBox(n) if shape is None: @@ -2035,7 +2048,9 @@ def __init__( if len(kwargs): raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.") - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) if dtype is None: dtype = torch.get_default_dtype() if domain is None: @@ -2270,14 +2285,20 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: r = torch.rand(_size([*shape, *self._safe_shape]), device=interval.device) r = interval * r r = self.space.low + r - r = r.to(self.dtype).to(self.device) + if r.dtype != self.dtype: + r = r.to(self.dtype) + if self.dtype is not None and r.device != self.device: + r = r.to(self.device) return r def _project(self, val: torch.Tensor) -> torch.Tensor: - low = self.space.low.to(val.device) - high = self.space.high.to(val.device) + low = self.space.low + high = self.space.high + if self.device != val.device: + low = low.to(val.device) + high = high.to(val.device) try: - val = val.clamp_(low.item(), high.item()) + val = torch.maximum(torch.minimum(val, high), low) except ValueError: low = low.expand_as(val) high = high.expand_as(val) @@ -2630,7 +2651,9 @@ def __init__( if isinstance(shape, int): shape = _size([shape]) - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) if dtype == torch.bool: min_value = False max_value = True @@ -2687,7 +2710,9 @@ def is_in(self, val: torch.Tensor) -> bool: return val.shape == shape and val.dtype == self.dtype def _project(self, val: torch.Tensor) -> torch.Tensor: - return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape) + return torch.as_tensor(val, dtype=self.dtype).reshape( + val.shape[: -self.ndim] + self.shape + ) def enumerate(self) -> Any: raise NotImplementedError("enumerate cannot be called with continuous specs.") @@ -2745,8 +2770,8 @@ def __eq__(self, other): # those specs are equivalent to a discrete spec if isinstance(other, Bounded): minval, maxval = _minmax_dtype(self.dtype) - minval = torch.as_tensor(minval).to(self.device, self.dtype) - maxval = torch.as_tensor(maxval).to(self.device, self.dtype) + minval = torch.as_tensor(minval, device=self.device, dtype=self.dtype) + maxval = torch.as_tensor(maxval, device=self.device, dtype=self.dtype) return ( Bounded( shape=self.shape, @@ -2835,7 +2860,9 @@ def __init__( mask: torch.Tensor | None = None, ): self.nvec = nvec - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) if shape is None: shape = _size((sum(nvec),)) else: @@ -3311,7 +3338,9 @@ def __init__( ): if shape is None: shape = _size([]) - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) space = CategoricalBox(n) super().__init__( shape=shape, space=space, device=device, dtype=dtype, domain="discrete" @@ -3858,7 +3887,9 @@ def __init__( if nvec.ndim < 1: nvec = nvec.unsqueeze(0) self.nvec = nvec - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) if shape is None: shape = nvec.shape else: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 6e8296a677a..da0c6dc3260 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -397,7 +397,7 @@ class AdditiveGaussianModule(TensorDictModuleBase): default: "action" safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space given the :obj:`TensorSpec.project` heuristic. - default: True + default: False device (torch.device, optional): the device where the buffers have to be stored. .. note:: @@ -420,7 +420,8 @@ def __init__( std: float = 1.0, *, action_key: Optional[NestedKey] = "action", - safe: bool = True, + # safe is already implemented because we project in the noise addition + safe: bool = False, device: torch.device | None = None, ): if not isinstance(sigma_init, float):