diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 2acf0b92e59..9723af0da8f 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -49,7 +49,7 @@ from torchrl.collectors.utils import split_trajectories from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING -from torchrl.envs.common import EnvBase +from torchrl.envs.common import _do_nothing, EnvBase from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( _aggregate_end_of_traj, @@ -472,8 +472,45 @@ def __init__( ) self.storing_device = storing_device + if self.storing_device is not None and self.storing_device.type != "cuda": + # Cuda handles sync + if torch.cuda.is_available(): + self._sync_storage = torch.cuda.synchronize + elif torch.backends.mps.is_available(): + self._sync_storage = torch.mps.synchronize + elif self.storing_device.type == "cpu": + self._sync_storage = _do_nothing + else: + raise RuntimeError("Non supported device") + else: + self._sync_storage = _do_nothing + self.env_device = env_device + if self.env_device is not None and self.env_device.type != "cuda": + # Cuda handles sync + if torch.cuda.is_available(): + self._sync_env = torch.cuda.synchronize + elif torch.backends.mps.is_available(): + self._sync_env = torch.mps.synchronize + elif self.env_device.type == "cpu": + self._sync_env = _do_nothing + else: + raise RuntimeError("Non supported device") + else: + self._sync_env = _do_nothing self.policy_device = policy_device + if self.policy_device is not None and self.policy_device.type != "cuda": + # Cuda handles sync + if torch.cuda.is_available(): + self._sync_policy = torch.cuda.synchronize + elif torch.backends.mps.is_available(): + self._sync_policy = torch.mps.synchronize + elif self.policy_device.type == "cpu": + self._sync_policy = _do_nothing + else: + raise RuntimeError("Non supported device") + else: + self._sync_policy = _do_nothing self.device = device # Check if we need to cast things from device to device # If the policy has a None device and the env too, no need to cast (we don't know @@ -503,7 +540,7 @@ def __init__( if self.env_device: self.env: EnvBase = self.env.to(self.env_device) elif self.env.device is not None: - # we we did not receive an env device, we use the device of the env + # we did not receive an env device, we use the device of the env self.env_device = self.env.device # If the storing device is not the same as the policy device, we have @@ -915,6 +952,7 @@ def rollout(self) -> TensorDictBase: policy_input = self._shuttle.to( self.policy_device, non_blocking=True ) + self._sync_policy() elif self.policy_device is None: # we know the tensordict has a device otherwise we would not be here # we can pass this, clear_device_ must have been called earlier @@ -933,6 +971,7 @@ def rollout(self) -> TensorDictBase: if self._cast_to_env_device: if self.env_device is not None: env_input = self._shuttle.to(self.env_device, non_blocking=True) + self._sync_env() elif self.env_device is None: # we know the tensordict has a device otherwise we would not be here # we can pass this, clear_device_ must have been called earlier @@ -954,6 +993,7 @@ def rollout(self) -> TensorDictBase: tensordicts.append( self._shuttle.to(self.storing_device, non_blocking=True) ) + self._sync_storage() else: tensordicts.append(self._shuttle) @@ -1000,16 +1040,6 @@ def rollout(self) -> TensorDictBase: ) return self._final_rollout - @staticmethod - def 
_update_device_wise(tensor0, tensor1): - # given 2 tensors, returns tensor0 if their identity matches, - # or a copy of tensor1 on the device of tensor0 otherwise - if tensor1 is None or tensor1 is tensor0: - return tensor0 - if tensor1.device == tensor0.device: - return tensor1 - return tensor1.to(tensor0.device, non_blocking=True) - @torch.no_grad() def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state.""" diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 383bce0386d..693fac6daf2 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1141,7 +1141,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any: # to be deprecated in v0.4 def map_device(tensor): if tensor.device != self.device: - return tensor.to(self.device, non_blocking=True) + return tensor.to(self.device, non_blocking=False) return tensor if is_tensor_collection(result): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 48085d21093..fd14a07377f 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -12,7 +12,7 @@ import os import weakref from collections import OrderedDict -from copy import deepcopy +from copy import copy, deepcopy from functools import wraps from multiprocessing import connection from multiprocessing.synchronize import Lock as MpLock @@ -32,7 +32,7 @@ ) from torchrl.data.tensor_specs import CompositeSpec from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING -from torchrl.envs.common import _EnvPostInit, EnvBase +from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase from torchrl.envs.env_creator import get_env_metadata # legacy @@ -174,8 +174,7 @@ class BatchedEnvBase(EnvBase): with a single worker will return a :class:`~SerialEnv` instead. This option has no effect with :class:`~SerialEnv`. Defaults to ``False``. non_blocking (bool, optional): if ``True``, device moves will be done using the - ``non_blocking=True`` option. Defaults to ``True`` for batched environments - on cuda devices, and ``False`` otherwise. + ``non_blocking=True`` option. Defaults to ``True``. mp_start_method (str, optional): the multiprocessing start method. Uses the default start method if not indicated ('spawn' by default in TorchRL if not initiated differently before first import). 
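Illustration (not part of the patch): the ``non_blocking=True`` default documented above is only safe because every asynchronous device move in this patch is paired with an explicit synchronization. A minimal sketch of that pattern, assuming a hypothetical ``make_sync`` helper that mirrors the selection logic added in collectors.py and a toy TensorDict:

import torch
from tensordict import TensorDict


def _do_nothing() -> None:
    # same no-op the patch falls back to when no sync is required
    return


def make_sync(dest: torch.device):
    # hypothetical helper mirroring the selection logic added in collectors.py:
    # copies onto a cuda device are ordered by the cuda stream, so an explicit
    # sync is only needed when the destination is not cuda
    if dest.type == "cuda":
        return _do_nothing
    if torch.cuda.is_available():
        return torch.cuda.synchronize
    if torch.backends.mps.is_available():
        return torch.mps.synchronize
    return _do_nothing


if torch.cuda.is_available():
    src = TensorDict({"obs": torch.randn(4, 3, device="cuda")}, batch_size=[4])
    dest = torch.device("cpu")
    sync = make_sync(dest)
    out = src.to(dest, non_blocking=True)
    sync()  # without this, out["obs"] may still be mid-copy when read
    print(out["obs"].sum())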
@@ -345,10 +344,103 @@ def __init__( def non_blocking(self): nb = self._non_blocking if nb is None: - nb = self.device is not None and self.device.type == "cuda" + nb = True self._non_blocking = nb return nb + @property + def _sync_m2w(self) -> Callable: + sync_func = self.__dict__.get("_sync_m2w_value", None) + if sync_func is None: + sync_m2w, sync_w2m = self._find_sync_values() + self.__dict__["_sync_m2w_value"] = sync_m2w + self.__dict__["_sync_w2m_value"] = sync_w2m + return sync_m2w + return sync_func + + @property + def _sync_w2m(self) -> Callable: + sync_func = self.__dict__.get("_sync_w2m_value", None) + if sync_func is None: + sync_m2w, sync_w2m = self._find_sync_values() + self.__dict__["_sync_m2w_value"] = sync_m2w + self.__dict__["_sync_w2m_value"] = sync_w2m + return sync_w2m + return sync_func + + def _find_sync_values(self): + """Returns the m2w and w2m sync values, in that order.""" + # Simplest case: everything is on the same device + worker_device = self.shared_tensordict_parent.device + self_device = self.device + if not self.non_blocking or ( + worker_device == self_device or self_device is None + ): + # even if they're both None, there is no device-to-device movement + return _do_nothing, _do_nothing + + if worker_device is None: + worker_not_main = [False] + + def find_all_worker_devices(item, worker_not_main=worker_not_main): + if hasattr(item, "device"): + worker_not_main[0] = worker_not_main[0] or ( + item.device != self_device + ) + + for td in self.shared_tensordicts: + td.apply(find_all_worker_devices, filter_empty=True) + if worker_not_main[0]: + if torch.cuda.is_available(): + worker_device = ( + torch.device("cuda") + if self_device.type != "cuda" + else torch.device("cpu") + ) + elif torch.backends.mps.is_available(): + worker_device = ( + torch.device("mps") + if self_device.type != "mps" + else torch.device("cpu") + ) + else: + raise RuntimeError("Did not find a valid worker device") + + if ( + worker_device is not None + and worker_device.type == "cuda" + and self_device is not None + and self_device.type == "cpu" + ): + return _do_nothing, _cuda_sync(worker_device) + if ( + worker_device is not None + and worker_device.type == "mps" + and self_device is not None + and self_device.type == "cpu" + ): + return _mps_sync(worker_device), _mps_sync(worker_device) + if ( + worker_device is not None + and worker_device.type == "cpu" + and self_device is not None + and self_device.type == "cuda" + ): + return _cuda_sync(self_device), _do_nothing + if ( + worker_device is not None + and worker_device.type == "cpu" + and self_device is not None + and self_device.type == "mps" + ): + return _mps_sync(self_device), _mps_sync(self_device) + + def __getstate__(self): + out = copy(self.__dict__) + out["_sync_m2w_value"] = None + out["_sync_w2m_value"] = None + return out + def _get_metadata( self, create_env_fn: List[Callable], create_env_kwargs: List[Dict] ): @@ -634,6 +726,9 @@ def _create_td(self) -> None: if not self.shared_tensordict_parent.is_memmap(): raise RuntimeError("memmap_() failed") self.shared_tensordicts = self.shared_tensordict_parent.unbind(0) + for td in self.shared_tensordicts: + td.lock_() + # we cache all the keys of the shared parent td for future use. This is # safe since the td is locked. 
self._cache_shared_keys = set(self.shared_tensordict_parent.keys(True, True)) @@ -695,20 +790,12 @@ def to(self, device: DEVICE_TYPING): if device == self.device: return self self._device = device - if not self.is_closed: - warn( - "Casting an open environment to another device requires closing and re-opening it. " - "This may have unexpected and unwanted effects (e.g. on seeding etc.)" - ) - # the tensordicts must be re-created on device - super().to(device) - self.close() - self.start() - else: - if self.__dict__["_input_spec"] is not None: - self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(device) - if self.__dict__["_output_spec"] is not None: - self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to(device) + self.__dict__["_sync_m2w_value"] = None + self.__dict__["_sync_w2m_value"] = None + if self.__dict__["_input_spec"] is not None: + self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(device) + if self.__dict__["_output_spec"] is not None: + self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to(device) return self @@ -796,6 +883,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: (self.num_workers,), device=self.device, dtype=torch.bool ) + tds = [] for i, _env in enumerate(self._envs): if not needs_resetting[i]: continue @@ -813,12 +901,18 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_ = tensordict_.clone(False) else: tensordict_ = None + tds.append((i, tensordict_)) + + self._sync_m2w() + for i, tensordict_ in tds: + _env = self._envs[i] _td = _env.reset(tensordict=tensordict_, **kwargs) try: self.shared_tensordicts[i].update_( _td, keys_to_update=list(self._selected_reset_keys_filt), + non_blocking=self.non_blocking, ) except RuntimeError as err: if "no_grad mode" in str(err): @@ -828,6 +922,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: "share_individual_td argument to True." ) raise + selected_output_keys = self._selected_reset_keys_filt device = self.device @@ -847,6 +942,7 @@ def select_and_clone(name, tensor): out = out.clear_device_() else: out = out.to(device, non_blocking=self.non_blocking) + self._sync_w2m() return out def _reset_proc_data(self, tensordict, tensordict_reset): @@ -862,18 +958,27 @@ def _step( ) -> TensorDict: tensordict_in = tensordict.clone(False) next_td = self.shared_tensordict_parent.get("next") + data_in = [] for i in range(self.num_workers): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. 
env_device = self._envs[i].device if env_device != self.device and env_device is not None: - data_in = tensordict_in[i].to( - env_device, non_blocking=self.non_blocking + data_in.append( + tensordict_in[i].to(env_device, non_blocking=self.non_blocking) ) else: - data_in = tensordict_in[i] - out_td = self._envs[i]._step(data_in) - next_td[i].update_(out_td, keys_to_update=list(self._env_output_keys)) + data_in.append(tensordict_in[i]) + + self._sync_m2w() + + for i, _data_in in enumerate(data_in): + out_td = self._envs[i]._step(_data_in) + next_td[i].update_( + out_td, + keys_to_update=list(self._env_output_keys), + non_blocking=self.non_blocking, + ) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -890,6 +995,7 @@ def select_and_clone(name, tensor): out = out.clear_device_() elif out.device != device: out = out.to(device, non_blocking=self.non_blocking) + self._sync_w2m() return out def __getattr__(self, attr: str) -> Any: @@ -1147,6 +1253,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): "_selected_step_keys": self._selected_step_keys, "has_lazy_inputs": self.has_lazy_inputs, "num_threads": num_sub_threads, + "non_blocking": self.non_blocking, } ) process = proc_fun(target=func, kwargs=kwargs[idx]) @@ -1203,8 +1310,11 @@ def step_and_maybe_reset( # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. + self.shared_tensordict_parent.update_( - tensordict, keys_to_update=self._env_input_keys + tensordict, + keys_to_update=self._env_input_keys, + non_blocking=self.non_blocking, ) next_td_passthrough = tensordict.get("next", None) if next_td_passthrough is not None: @@ -1213,9 +1323,12 @@ def step_and_maybe_reset( # We keep track of which keys are present to let the worker know what # should be passd to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) - self.shared_tensordict_parent.get("next").update_(next_td_passthrough) + self.shared_tensordict_parent.get("next").update_( + next_td_passthrough, non_blocking=self.non_blocking + ) else: next_td_keys = None + self._sync_m2w() for i in range(self.num_workers): self.parent_channels[i].send(("step_and_maybe_reset", next_td_keys)) @@ -1247,6 +1360,7 @@ def step_and_maybe_reset( device=device, filter_empty=True, ) + self._sync_w2m() else: next_td = next_td.clone().clear_device_() tensordict_ = tensordict_.clone().clear_device_() @@ -1264,8 +1378,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. 
+ self.shared_tensordict_parent.update_( - tensordict, keys_to_update=list(self._env_input_keys) + tensordict, + keys_to_update=list(self._env_input_keys), + non_blocking=self.non_blocking, ) next_td_passthrough = tensordict.get("next", None) if next_td_passthrough is not None: @@ -1274,10 +1391,14 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We keep track of which keys are present to let the worker know what # should be passd to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) - self.shared_tensordict_parent.get("next").update_(next_td_passthrough) + self.shared_tensordict_parent.get("next").update_( + next_td_passthrough, non_blocking=self.non_blocking + ) else: next_td_keys = None + self._sync_m2w() + if self.event is not None: self.event.record() self.event.synchronize() @@ -1294,20 +1415,25 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: next_td = self.shared_tensordict_parent.get("next") device = self.device - def select_and_clone(name, tensor): - if name in self._selected_step_keys: - return tensor.clone() + if next_td.device != device and device is not None: + + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.to(device, non_blocking=self.non_blocking) + + else: + + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() out = next_td.named_apply( select_and_clone, nested_keys=True, filter_empty=True, + device=device, ) - if out.device != device: - if device is None: - out.clear_device_() - else: - out = out.to(device, non_blocking=self.non_blocking) + self._sync_w2m() return out @torch.no_grad() @@ -1328,13 +1454,18 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: (self.num_workers,), device=self.device, dtype=torch.bool ) - workers = [] - - for i, channel in enumerate(self.parent_channels): + outs = [] + for i in range(self.num_workers): if tensordict is not None: tensordict_ = tensordict[i] if tensordict_.is_empty(): tensordict_ = None + elif self.device is not None and self.device.type == "mps": + # copy_ fails when moving mps->cpu using copy_ + # in some cases when a view of an mps tensor is used. 
+ # We know the shared tensors are not MPS, so we can + # safely assume that the shared tensors are on cpu + tensordict_ = tensordict_.to("cpu") else: tensordict_ = None if not needs_resetting[i]: @@ -1347,10 +1478,13 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self.shared_tensordicts[i].update_( self.shared_tensordicts[i].get("next"), keys_to_update=list(self._selected_reset_keys), + non_blocking=self.non_blocking, ) if tensordict_ is not None: self.shared_tensordicts[i].update_( - tensordict_, keys_to_update=list(self._selected_reset_keys) + tensordict_, + keys_to_update=list(self._selected_reset_keys), + non_blocking=self.non_blocking, ) continue if tensordict_ is not None: @@ -1359,7 +1493,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # This way we can avoid calling select over all the keys in the shared tensordict def tentative_update(val, other): if other is not None: - val.copy_(other) + val.copy_(other, non_blocking=self.non_blocking) return val self.shared_tensordicts[i].apply_( @@ -1368,11 +1502,14 @@ def tentative_update(val, other): out = ("reset", tdkeys) else: out = ("reset", False) + outs.append((i, out)) - channel.send(out) - workers.append(i) + self._sync_m2w() - for i in workers: + for i, out in outs: + self.parent_channels[i].send(out) + + for i, _ in outs: event = self._events[i] event.wait(self._timeout) event.clear() @@ -1380,21 +1517,25 @@ def tentative_update(val, other): selected_output_keys = self._selected_reset_keys_filt device = self.device - def select_and_clone(name, tensor): - if name in selected_output_keys: - return tensor.clone() + if self.shared_tensordict_parent.device != device and device is not None: + + def select_and_clone(name, tensor): + if name in selected_output_keys: + return tensor.to(device, non_blocking=self.non_blocking) + + else: + + def select_and_clone(name, tensor): + if name in selected_output_keys: + return tensor.clone() out = self.shared_tensordict_parent.named_apply( select_and_clone, nested_keys=True, filter_empty=True, + device=device, ) - - if out.device != device: - if device is None: - out.clear_device_() - else: - out = out.to(device, non_blocking=self.non_blocking) + self._sync_w2m() return out @_check_start @@ -1513,6 +1654,7 @@ def _run_worker_pipe_shared_mem( _selected_input_keys=None, _selected_reset_keys=None, _selected_step_keys=None, + non_blocking: bool = False, has_lazy_inputs: bool = False, verbose: bool = False, num_threads: int | None = None, # for fork start method @@ -1606,6 +1748,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): shared_tensordict.update_( cur_td, keys_to_update=list(_selected_reset_keys), + non_blocking=non_blocking, ) if event is not None: event.record() @@ -1626,7 +1769,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): else: input = root_shared_tensordict next_td = env._step(input) - next_shared_tensordict.update_(next_td) + next_shared_tensordict.update_(next_td, non_blocking=non_blocking) if event is not None: event.record() event.synchronize() @@ -1653,8 +1796,8 @@ def look_for_cuda(tensor, has_cuda=has_cuda): else: input = root_shared_tensordict td, root_next_td = env.step_and_maybe_reset(input) - next_shared_tensordict.update_(td.pop("next")) - root_shared_tensordict.update_(root_next_td) + next_shared_tensordict.update_(td.pop("next"), non_blocking=non_blocking) + root_shared_tensordict.update_(root_next_td, non_blocking=non_blocking) if event is not None: event.record() @@ -1729,5 +1872,13 @@ def 
_stackable(*tensordicts): return False +def _cuda_sync(device): + return functools.partial(torch.cuda.synchronize, device=device) + + +def _mps_sync(device): + return torch.mps.synchronize + + # Create an alias for possible imports _BatchedEnv = BatchedEnvBase diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index a0a7cb23dfd..fb4a5767597 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key +from tensordict.base import NO_DEFAULT from tensordict.utils import NestedKey from torchrl._utils import _replace_last, implement_for, prod, seed_generator @@ -2560,24 +2561,29 @@ def _rollout_stop_early( env_device, callback, ): + # Get the sync func + if auto_cast_to_device: + sync_func = _get_sync_func(policy_device, env_device) tensordicts = [] for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: tensordict = tensordict.to(policy_device, non_blocking=True) + sync_func() else: tensordict.clear_device_() tensordict = policy(tensordict) if auto_cast_to_device: if env_device is not None: tensordict = tensordict.to(env_device, non_blocking=True) + sync_func() else: tensordict.clear_device_() tensordict = self.step(tensordict) tensordicts.append(tensordict.clone(False)) if i == max_steps - 1: - # we don't truncated as one could potentially continue the run + # we don't truncate as one could potentially continue the run break tensordict = self._step_mdp(tensordict) @@ -2606,18 +2612,22 @@ def _rollout_nonstop( env_device, callback, ): + if auto_cast_to_device: + sync_func = _get_sync_func(policy_device, env_device) tensordicts = [] tensordict_ = tensordict for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: tensordict_ = tensordict_.to(policy_device, non_blocking=True) + sync_func() else: tensordict_.clear_device_() tensordict_ = policy(tensordict_) if auto_cast_to_device: if env_device is not None: tensordict_ = tensordict_.to(env_device, non_blocking=True) + sync_func() else: tensordict_.clear_device_() if i == max_steps - 1: @@ -2626,7 +2636,7 @@ def _rollout_nonstop( tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_) tensordicts.append(tensordict) if i == max_steps - 1: - # we don't truncated as one could potentially continue the run + # we don't truncate as one could potentially continue the run break if callback is not None: callback(self, tensordict) @@ -2711,7 +2721,11 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: any_done = done.any() if any_done: tensordict._set_str( - "_reset", done.clone(), validated=True, inplace=False + "_reset", + done.clone(), + validated=True, + inplace=False, + non_blocking=False, ) else: any_done = _terminated_or_truncated( @@ -2903,12 +2917,20 @@ def __init__( self, *args, dtype: Optional[np.dtype] = None, - device: DEVICE_TYPING = None, + device: DEVICE_TYPING = NO_DEFAULT, batch_size: Optional[torch.Size] = None, allow_done_after_reset: bool = False, **kwargs, ): - if device is None: + if device is NO_DEFAULT: + warnings.warn( + "Your wrapper was not given a device. Currently, this " + "value will default to 'cpu'. From v0.5 it will " + "default to `None`. With a device of None, no device casting " + "is performed and the resulting tensordicts are deviceless. 
" + "Please set your device accordingly.", + category=DeprecationWarning, + ) device = torch.device("cpu") super().__init__( device=device, @@ -2936,6 +2958,22 @@ def __init__( self.is_closed = False self._init_env() # runs all the steps to have a ready-to-use env + def _sync_device(self): + sync_func = self.__dict__.get("_sync_device_val", None) + if sync_func is None: + device = self.device + if device.type != "cuda": + if torch.cuda.is_available(): + self._sync_device_val = torch.cuda.synchronize + elif torch.backends.mps.is_available(): + self._sync_device_val = torch.cuda.synchronize + elif device.type == "cpu": + self._sync_device_val = _do_nothing + else: + self._sync_device_val = _do_nothing + return self._sync_device + return sync_func + @abc.abstractmethod def _check_kwargs(self, kwargs: Dict): raise NotImplementedError @@ -3017,3 +3055,24 @@ def make_tensordict( tensordict.set("action", env.action_spec.rand(), inplace=False) tensordict = env.step(tensordict) return tensordict.zero_() + + +def _get_sync_func(policy_device, env_device): + if torch.cuda.is_available(): + # Look for a specific device + if policy_device is not None and policy_device.type == "cuda": + if env_device is None or env_device.type == "cuda": + return torch.cuda.synchronize + return functools.partial(torch.cuda.synchronize, device=policy_device) + if env_device is not None and env_device.type == "cuda": + if policy_device is None: + return torch.cuda.synchronize + return functools.partial(torch.cuda.synchronize, device=env_device) + return torch.cuda.synchronize + if torch.backends.mps.is_available(): + return torch.mps.synchronize + return _do_nothing + + +def _do_nothing(): + return diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 3c52f941bdb..6c1dbbd0389 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -329,6 +329,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: ) if self.device is not None: tensordict_out = tensordict_out.to(self.device, non_blocking=True) + self._sync_device() if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): @@ -372,7 +373,9 @@ def _reset( for key, item in self.observation_spec.items(True, True): if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() - tensordict_out = tensordict_out.to(self.device, non_blocking=True) + if self.device is not None: + tensordict_out = tensordict_out.to(self.device, non_blocking=True) + self._sync_device() return tensordict_out @abc.abstractmethod diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 5c5a387f762..d8043cb9ef7 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -769,8 +769,9 @@ def _get_batch_size(self, env): @implement_for("gymnasium") # gymnasium wants the unwrapped env def _get_batch_size(self, env): # noqa: F811 - if hasattr(env, "num_envs"): - batch_size = torch.Size([env.unwrapped.num_envs, *self.batch_size]) + env_unwrapped = env.unwrapped + if hasattr(env_unwrapped, "num_envs"): + batch_size = torch.Size([env_unwrapped.num_envs, *self.batch_size]) else: batch_size = self.batch_size return batch_size @@ -929,6 +930,18 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811 self._seed_calls_reset = False self._env.seed(seed=seed) + @implement_for("gym") + def _reward_space(self, env): + if hasattr(env, "reward_space") and env.reward_space is not None: + return env.reward_space + + @implement_for("gymnasium") + def _reward_space(self, env): # noqa: F811 + env = 
env.unwrapped + if hasattr(env, "reward_space") and env.reward_space is not None: + rs = env.reward_space + return rs + def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 action_spec = _gym_to_torchrl_spec_transform( env.action_space, @@ -952,9 +965,10 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 elif observation_spec.shape[: len(self.batch_size)] != self.batch_size: observation_spec.shape = self.batch_size - if hasattr(env, "reward_space") and env.reward_space is not None: + reward_space = self._reward_space(env) + if reward_space is not None: reward_spec = _gym_to_torchrl_spec_transform( - env.reward_space, + reward_space, device=self.device, categorical_action_encoding=self._categorical_action_encoding, ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b6ebead77aa..9c2ec5a052f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -56,7 +56,7 @@ TensorSpec, UnboundedContinuousTensorSpec, ) -from torchrl.envs.common import _EnvPostInit, EnvBase, make_tensordict +from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict from torchrl.envs.transforms import functional as F from torchrl.envs.transforms.utils import ( _get_reset, @@ -3765,11 +3765,20 @@ def __init__( device, orig_device=None, ): - self.device = torch.device(device) + device = self.device = torch.device(device) self.orig_device = ( torch.device(orig_device) if orig_device is not None else orig_device ) super().__init__() + if device.type != "cuda": + if torch.cuda.is_available(): + self._sync_device = torch.cuda.synchronize + elif torch.backends.mps.is_available(): + self._sync_device = torch.mps.synchronize + elif device.type == "cpu": + self._sync_device = _do_nothing + else: + self._sync_device = _do_nothing def set_container(self, container: Union[Transform, EnvBase]) -> None: if self.orig_device is None: @@ -3786,10 +3795,14 @@ def set_container(self, container: Union[Transform, EnvBase]) -> None: @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=True) + result = tensordict.to(self.device, non_blocking=True) + self._sync_device() + return result def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - return tensordict.to(self.device, non_blocking=True) + result = tensordict.to(self.device, non_blocking=True) + self._sync_device() + return result def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase @@ -3799,11 +3812,30 @@ def _reset( def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: parent = self.parent - if parent is None: - if self.orig_device is None: - return tensordict - return tensordict.to(self.orig_device, non_blocking=True) - return tensordict.to(parent.device, non_blocking=True) + orig_device = self.orig_device if parent is None else parent.device + if orig_device is not None: + result = tensordict.to(orig_device, non_blocking=True) + self._sync_orig_device() + return result + return tensordict + + @property + def _sync_orig_device(self): + sync_func = self.__dict__.get("_sync_orig_device_val", None) + if sync_func is None: + parent = self.parent + device = self.orig_device if parent is None else parent.device + if device.type != "cuda": + if torch.cuda.is_available(): + self._sync_orig_device_val = torch.cuda.synchronize + elif torch.backends.mps.is_available(): + 
self._sync_orig_device_val = torch.mps.synchronize + elif device.type == "cpu": + self._sync_orig_device_val = _do_nothing + else: + self._sync_orig_device_val = _do_nothing + return self._sync_orig_device + return sync_func def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: return input_spec.to(self.device) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 6094ccf9f77..e0fea1751ed 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -165,7 +165,13 @@ def validate(self, tensordict): + [unravel_key(("next", key)) for key in self.done_keys] + [unravel_key(("next", key)) for key in self.reward_keys] ) - actual = set(tensordict.keys(True, True)) + + def _is_reset(key: NestedKey): + if isinstance(key, str): + return key == "_reset" + return key[-1] == "_reset" + + actual = {key for key in tensordict.keys(True, True) if not _is_reset(key)} expected = set(expected) self.validated = expected.intersection(actual) == expected if not self.validated: @@ -212,7 +218,9 @@ def _grab_and_place( ) else: val = cls._grab_and_place(subdict, val, val_out) - data_out._set_str(key, val, validated=True, inplace=False) + data_out._set_str( + key, val, validated=True, inplace=False, non_blocking=False + ) return data_out @classmethod @@ -464,7 +472,9 @@ def _set_single_key( new_val = dest._get_str(k, None) if new_val is None: new_val = val.empty() - dest._set_str(k, new_val, inplace=False, validated=True) + dest._set_str( + k, new_val, inplace=False, validated=True, non_blocking=False + ) source = val dest = new_val else: @@ -472,7 +482,7 @@ def _set_single_key( val = val.to(device, non_blocking=True) elif clone: val = val.clone() - dest._set_str(k, val, inplace=False, validated=True) + dest._set_str(k, val, inplace=False, validated=True, non_blocking=False) # This is a temporary solution to understand if a key is heterogeneous # while not having performance impact when the exception is not raised except RuntimeError as err: @@ -504,12 +514,16 @@ def _set(source, dest, key, total_key, excluded): ) if non_empty_local: # dest.set(key, new_val) - dest._set_str(key, new_val, inplace=False, validated=True) + dest._set_str( + key, new_val, inplace=False, validated=True, non_blocking=False + ) non_empty = non_empty_local else: non_empty = True # dest.set(key, val) - dest._set_str(key, val, inplace=False, validated=True) + dest._set_str( + key, val, inplace=False, validated=True, non_blocking=False + ) # This is a temporary solution to understand if a key is heterogeneous # while not having performance impact when the exception is not raised except RuntimeError as err:
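Usage sketch (not part of the patch): how the synchronizer returned by the new ``_get_sync_func`` helper in torchrl/envs/common.py is meant to pair with the ``auto_cast_to_device`` casts in the rollout helpers; the device choices below are assumptions made for the example.

import torch

from torchrl.envs.common import _get_sync_func  # added by this patch

policy_device = torch.device("cuda") if torch.cuda.is_available() else None
env_device = torch.device("cpu")
sync = _get_sync_func(policy_device, env_device)

# Inside _rollout_stop_early / _rollout_nonstop, every non-blocking cast is
# immediately followed by a call to the synchronizer:
#     tensordict = tensordict.to(policy_device, non_blocking=True); sync()
#     tensordict = policy(tensordict)
#     tensordict = tensordict.to(env_device, non_blocking=True); sync()
# On a cpu-only build (no cuda, no mps), _get_sync_func returns the _do_nothing
# no-op, so these calls are free.
sync()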