From 2c485dd69ee2ee49c18482a4da1ac1784e150947 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 28 Mar 2024 17:40:56 +0000 Subject: [PATCH] [Feature] Fine grained DeviceCastTransform (#2041) --- test/test_transforms.py | 234 +++++++++++++++++++++++++- torchrl/data/tensor_specs.py | 7 +- torchrl/envs/transforms/transforms.py | 148 +++++++++++++--- 3 files changed, 365 insertions(+), 24 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index a4e597a89e2..8a11e849e30 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -9331,7 +9331,239 @@ def test_transform_inverse(self): return -class TestDeviceCastTransform(TransformBase): +class TestDeviceCastTransformPart(TransformBase): + @pytest.mark.parametrize("in_keys", ["observation"]) + @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) + @pytest.mark.parametrize("in_keys_inv", ["action"]) + @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) + def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv): + env = ContinuousActionVecMockEnv(device="cpu:0") + env = TransformedEnv( + env, + DeviceCastTransform( + "cpu:1", + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ), + ) + assert env.device is None + check_env_specs(env) + + @pytest.mark.parametrize("in_keys", ["observation"]) + @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) + @pytest.mark.parametrize("in_keys_inv", ["action"]) + @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) + def test_serial_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv): + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(device="cpu:0"), + DeviceCastTransform( + "cpu:1", + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ), + ) + + env = SerialEnv(2, make_env) + assert env.device is None + check_env_specs(env) + + @pytest.mark.parametrize("in_keys", ["observation"]) + @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) + @pytest.mark.parametrize("in_keys_inv", ["action"]) + @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) + def test_parallel_trans_env_check( + self, in_keys, out_keys, in_keys_inv, out_keys_inv + ): + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(device="cpu:0"), + DeviceCastTransform( + "cpu:1", + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ), + ) + + env = ParallelEnv( + 2, + make_env, + mp_start_method="fork" if not torch.cuda.is_available() else "spawn", + ) + assert env.device is None + try: + check_env_specs(env) + finally: + env.close() + + @pytest.mark.parametrize("in_keys", ["observation"]) + @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) + @pytest.mark.parametrize("in_keys_inv", ["action"]) + @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) + def test_trans_serial_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv): + def make_env(): + return ContinuousActionVecMockEnv(device="cpu:0") + + env = TransformedEnv( + SerialEnv(2, make_env), + DeviceCastTransform( + "cpu:1", + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ), + ) + assert env.device is None + check_env_specs(env) + + @pytest.mark.parametrize("in_keys", ["observation"]) + @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) + @pytest.mark.parametrize("in_keys_inv", 
["action"]) + @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) + def test_trans_parallel_env_check( + self, in_keys, out_keys, in_keys_inv, out_keys_inv + ): + def make_env(): + return ContinuousActionVecMockEnv(device="cpu:0") + + env = TransformedEnv( + ParallelEnv( + 2, + make_env, + mp_start_method="fork" if not torch.cuda.is_available() else "spawn", + ), + DeviceCastTransform( + "cpu:1", + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ), + ) + assert env.device is None + try: + check_env_specs(env) + finally: + env.close() + + def test_transform_no_env(self): + t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"]) + td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0") + tdt = t._call(td) + assert tdt.device is None + + @pytest.mark.parametrize("in_keys", ["observation"]) + @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) + @pytest.mark.parametrize("in_keys_inv", ["action"]) + @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) + def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv): + env = ContinuousActionVecMockEnv(device="cpu:0") + env = TransformedEnv( + env, + DeviceCastTransform( + "cpu:1", + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ), + ) + assert env.device is None + assert env.transform.device == torch.device("cpu:1") + assert env.transform.orig_device == torch.device("cpu:0") + + def test_transform_compose(self): + t = Compose( + DeviceCastTransform( + "cpu:1", + "cpu:0", + in_keys=["a"], + out_keys=["b"], + in_keys_inv=["c"], + out_keys_inv=["d"], + ) + ) + + td = TensorDict( + { + "a": torch.randn((), device="cpu:0"), + "c": torch.randn((), device="cpu:1"), + }, + [], + device="cpu:0", + ) + tdt = t._call(td) + tdit = t._inv_call(td) + + assert tdt.device is None + assert tdit.device is None + + def test_transform_model(self): + t = nn.Sequential( + Compose( + DeviceCastTransform( + "cpu:1", + "cpu:0", + in_keys=["a"], + out_keys=["b"], + in_keys_inv=["c"], + out_keys_inv=["d"], + ) + ) + ) + td = TensorDict( + { + "a": torch.randn((), device="cpu:0"), + "c": torch.randn((), device="cpu:1"), + }, + [], + device="cpu:0", + ) + tdt = t(td) + + assert tdt.device is None + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + @pytest.mark.parametrize("storage", [LazyTensorStorage]) + def test_transform_rb(self, rbclass, storage): + # we don't test casting to cuda on Memmap tensor storage since it's discouraged + t = Compose( + DeviceCastTransform( + "cpu:1", + "cpu:0", + in_keys=["a"], + out_keys=["b"], + in_keys_inv=["c"], + out_keys_inv=["d"], + ) + ) + rb = rbclass(storage=storage(max_size=20, device="auto")) + rb.append_transform(t) + td = TensorDict( + { + "a": torch.randn((), device="cpu:0"), + "c": torch.randn((), device="cpu:1"), + }, + [], + device="cpu:0", + ) + rb.add(td) + assert rb._storage._storage.device is None + assert rb.sample(4).device is None + + def test_transform_inverse(self): + # Tested before + return + + +class TestDeviceCastTransformWhole(TransformBase): def test_single_trans_env_check(self): env = ContinuousActionVecMockEnv(device="cpu:0") env = TransformedEnv(env, DeviceCastTransform("cpu:1")) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 9f50e97d9fe..71598938eab 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -542,7 +542,7 @@ def 
decorator(func): def clear_device_(self): """A no-op for all leaf specs (which must have a device).""" - pass + return self def encode( self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False @@ -866,6 +866,7 @@ def clear_device_(self): """Clears the device of the CompositeSpec.""" for spec in self._specs: spec.clear_device_() + return self def __getitem__(self, item): is_key = isinstance(item, str) or ( @@ -3594,8 +3595,10 @@ def device(self, device: DEVICE_TYPING): def clear_device_(self): """Clears the device of the CompositeSpec.""" - for spec in self._specs: + self._device = None + for spec in self._specs.values(): spec.clear_device_() + return self def __getitem__(self, idx): """Indexes the current CompositeSpec based on the provided index.""" diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index c3ad6dbb764..14a8e9e1a02 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6,6 +6,7 @@ from __future__ import annotations import collections +import functools import importlib.util import multiprocessing as mp import warnings @@ -3779,12 +3780,27 @@ def __init__( self, device, orig_device=None, + *, + in_keys=None, + out_keys=None, + in_keys_inv=None, + out_keys_inv=None, ): device = self.device = torch.device(device) self.orig_device = ( torch.device(orig_device) if orig_device is not None else orig_device ) - super().__init__() + super().__init__( + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ) + self._map_env_device = not self.in_keys and not self.in_keys_inv + + self._rename_keys = self.in_keys != self.out_keys + self._rename_keys_inv = self.in_keys_inv != self.out_keys_inv + if device.type != "cuda": if torch.cuda.is_available(): self._sync_device = torch.cuda.synchronize @@ -3808,16 +3824,44 @@ def set_container(self, container: Union[Transform, EnvBase]) -> None: self.orig_device = device return super().set_container(container) + def _to(self, name, tensor): + if name in self.in_keys: + return tensor.to(self.device, non_blocking=True) + return tensor + + def _to_inv(self, name, tensor, device): + if name in self.in_keys_inv: + return tensor.to(device, non_blocking=True) + return tensor + @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - result = tensordict.to(self.device, non_blocking=True) + if self._map_env_device: + result = tensordict.to(self.device, non_blocking=True) + self._sync_device() + return result + tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None) + if self._rename_keys: + for in_key, out_key in zip(self.in_keys, self.out_keys): + if out_key != in_key: + tensordict_t.rename_key_(in_key, out_key) + tensordict_t.set(in_key, tensordict.get(in_key)) self._sync_device() - return result + return tensordict_t def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - result = tensordict.to(self.device, non_blocking=True) + if self._map_env_device: + result = tensordict.to(self.device, non_blocking=True) + self._sync_device() + return result + tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None) + if self._rename_keys: + for in_key, out_key in zip(self.in_keys, self.out_keys): + if out_key != in_key: + tensordict_t.rename_key_(in_key, out_key) + tensordict_t.set(in_key, tensordict.get(in_key)) self._sync_device() - return result + return tensordict_t def _reset( self, tensordict: TensorDictBase, 
tensordict_reset: TensorDictBase @@ -3827,12 +3871,25 @@ def _reset( def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: parent = self.parent - orig_device = self.orig_device if parent is None else parent.device - if orig_device is not None: - result = tensordict.to(orig_device, non_blocking=True) + device = self.orig_device if parent is None else parent.device + if device is None: + return tensordict + if self._map_env_device: + result = tensordict.to(device, non_blocking=True) self._sync_orig_device() return result - return tensordict + tensordict_t = tensordict.named_apply( + functools.partial(self._to_inv, device=device), + nested_keys=True, + device=None, + ) + if self._rename_keys_inv: + for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + if out_key != in_key: + tensordict_t.rename_key_(in_key, out_key) + tensordict_t.set(in_key, tensordict.get(in_key)) + self._sync_orig_device() + return tensordict_t @property def _sync_orig_device(self): @@ -3852,27 +3909,76 @@ def _sync_orig_device(self): return self._sync_orig_device return sync_func - def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: - return input_spec.to(self.device) + def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: + if self._map_env_device: + return input_spec.to(self.device) + else: + return super().transform_input_spec(input_spec) - def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: - return reward_spec.to(self.device) + def transform_action_spec(self, full_action_spec: CompositeSpec) -> CompositeSpec: + full_action_spec = full_action_spec.clear_device_() + for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + if in_key not in full_action_spec.keys(True, True): + continue + full_action_spec[out_key] = full_action_spec[in_key].to(self.device) + return full_action_spec - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - return observation_spec.to(self.device) + def transform_state_spec(self, full_state_spec: CompositeSpec) -> CompositeSpec: + full_state_spec = full_state_spec.clear_device_() + for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + if in_key not in full_state_spec.keys(True, True): + continue + full_state_spec[out_key] = full_state_spec[in_key].to(self.device) + return full_state_spec def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: - return output_spec.to(self.device) + if self._map_env_device: + return output_spec.to(self.device) + else: + return super().transform_output_spec(output_spec) - def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec: - return done_spec.to(self.device) + def transform_observation_spec( + self, observation_spec: CompositeSpec + ) -> CompositeSpec: + observation_spec = observation_spec.clear_device_() + for in_key, out_key in zip(self.in_keys, self.out_keys): + if in_key not in observation_spec.keys(True, True): + continue + observation_spec[out_key] = observation_spec[in_key].to(self.device) + return observation_spec + + def transform_done_spec(self, full_done_spec: CompositeSpec) -> CompositeSpec: + full_done_spec = full_done_spec.clear_device_() + for in_key, out_key in zip(self.in_keys, self.out_keys): + if in_key not in full_done_spec.keys(True, True): + continue + full_done_spec[out_key] = full_done_spec[in_key].to(self.device) + return full_done_spec + + def transform_reward_spec(self, full_reward_spec: CompositeSpec) -> CompositeSpec: + full_reward_spec = 
full_reward_spec.clear_device_() + for in_key, out_key in zip(self.in_keys, self.out_keys): + if in_key not in full_reward_spec.keys(True, True): + continue + full_reward_spec[out_key] = full_reward_spec[in_key].to(self.device) + return full_reward_spec def transform_env_device(self, device): - return self.device + if self._map_env_device: + return self.device + # In all other cases the device is not defined + return None def __repr__(self) -> str: - s = f"{self.__class__.__name__}(device={self.device}, orig_device={self.orig_device})" - return s + if self._map_env_device: + return f"{self.__class__.__name__}(device={self.device}, orig_device={self.orig_device})" + device = indent(4 * " ", f"device={self.device}") + orig_device = indent(4 * " ", f"orig_device={self.orig_device}") + in_keys = indent(4 * " ", f"in_keys={self.in_keys}") + out_keys = indent(4 * " ", f"out_keys={self.out_keys}") + in_keys_inv = indent(4 * " ", f"in_keys_inv={self.in_keys_inv}") + out_keys_inv = indent(4 * " ", f"out_keys_inv={self.out_keys_inv}") + return f"{self.__class__.__name__}(\n{device},\n{orig_device},\n{in_keys},\n{out_keys},\n{in_keys_inv},\n{out_keys_inv})" class CatTensors(Transform):
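
A minimal usage sketch of the fine-grained cast introduced above, assembled from the new tests (the keys, the "cpu:0"/"cpu:1" devices and the Compose wrapping mirror TestDeviceCastTransformPart); it is illustrative only and not part of the diff:

    import torch
    from tensordict import TensorDict
    from torchrl.envs.transforms import Compose, DeviceCastTransform

    # Cast only selected entries: "a" is moved to cpu:1 in the forward (step)
    # direction, while "c" is moved back to cpu:0 in the inverse (action) direction.
    t = Compose(
        DeviceCastTransform(
            "cpu:1",            # destination device for in_keys
            "cpu:0",            # orig_device, used for in_keys_inv when no parent env is set
            in_keys=["a"],
            out_keys=["b"],
            in_keys_inv=["c"],
            out_keys_inv=["d"],
        )
    )

    td = TensorDict(
        {
            "a": torch.randn((), device="cpu:0"),
            "c": torch.randn((), device="cpu:1"),
        },
        [],
        device="cpu:0",
    )

    # Only the listed keys are cast and renamed, so the result may hold tensors on
    # several devices and therefore no longer advertises a single device.
    out = t(td)
    assert out.device is None

When such a transform is attached to an environment (as in test_transform_env above), the wrapped env likewise stops reporting a global device (env.device is None), while the target and original devices remain available as env.transform.device and env.transform.orig_device.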