From a87094d287ac2e1063b6f2d5efd780015c596c68 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Mar 2024 08:52:32 +0000 Subject: [PATCH] [Feature] RB MultiStep transform (#2008) --- docs/source/reference/data.rst | 8 + test/test_cost.py | 44 ++-- test/test_postprocs.py | 47 ++-- test/test_transforms.py | 93 ++++++++ torchrl/data/postprocs/postprocs.py | 201 +++++++++++----- torchrl/data/replay_buffers/replay_buffers.py | 8 + torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/rb_transforms.py | 217 ++++++++++++++++++ torchrl/envs/transforms/transforms.py | 4 +- 10 files changed, 504 insertions(+), 120 deletions(-) create mode 100644 torchrl/envs/transforms/rb_transforms.py diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 6638a3b1513..49708e5e404 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -823,3 +823,11 @@ Utils consolidate_spec check_no_exclusive_keys contains_lazy_spec + +.. currentmodule:: torchrl.envs.transforms.rb_transforms + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + MultiStepTransform diff --git a/test/test_cost.py b/test/test_cost.py index d9e28046132..581e9247772 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -548,7 +548,7 @@ def test_dqn_state_dict(self, delay_value, device, action_spec_type): loss_fn2 = DQNLoss(actor, loss_function="l2", delay_value=delay_value) loss_fn2.load_state_dict(sd) - @pytest.mark.parametrize("n", range(4)) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) @@ -580,7 +580,7 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): with torch.no_grad(): loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*td.keys(True, True))) _loss = sum( [item for name, item in loss.items() if name.startswith("loss")] @@ -1126,7 +1126,7 @@ def test_qmixer_state_dict(self, delay_value, device, action_spec_type): loss_fn2 = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) loss_fn2.load_state_dict(sd) - @pytest.mark.parametrize("n", range(4)) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) @@ -1159,7 +1159,7 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) with torch.no_grad(): loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*td.keys(True, True))) _loss = sum( [item for name, item in loss.items() if name.startswith("loss")] @@ -1803,7 +1803,7 @@ def test_ddpg_separate_losses( raise NotImplementedError(k) loss_fn.zero_grad() - @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)]) def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9): @@ -1834,7 +1834,7 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9): with torch.no_grad(): loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum( [item for name, item in loss.items() if 
name.startswith("loss_")] @@ -2436,7 +2436,7 @@ def test_td3_separate_losses( loss_fn.zero_grad() @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") - @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("delay_actor,delay_qvalue", [(False, False), (True, True)]) @pytest.mark.parametrize("policy_noise", [0.1, 1.0]) @@ -2482,7 +2482,7 @@ def test_td3_batcher( np.random.seed(0) loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum( [item for name, item in loss.items() if name.startswith("loss_")] @@ -3231,7 +3231,7 @@ def test_sac_separate_losses( raise NotImplementedError(k) loss_fn.zero_grad() - @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("delay_value", (True, False)) @pytest.mark.parametrize("delay_actor", (True, False)) @pytest.mark.parametrize("delay_qvalue", (True, False)) @@ -3295,7 +3295,7 @@ def test_sac_batcher( torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum( [item for name, item in loss.items() if name.startswith("loss_")] @@ -3930,7 +3930,7 @@ def test_discrete_sac_state_dict( ) loss_fn2.load_state_dict(sd) - @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("num_qvalue", [2]) @pytest.mark.parametrize("device", get_default_devices()) @@ -3986,7 +3986,7 @@ def test_discrete_sac_batcher( torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum( [item for name, item in loss.items() if name.startswith("loss_")] @@ -4875,7 +4875,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): # TODO: find a way to compare the losses: problem is that we sample actions either sequentially or in batch, # so setting seed has little impact - @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) @@ -4918,7 +4918,7 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum( [item for name, item in loss.items() if name.startswith("loss_")] @@ -5487,7 +5487,7 @@ def test_cql_state_dict( ) loss_fn2.load_state_dict(sd) - @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("delay_actor", (True, False)) @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("max_q_backup", [True, False]) @@ -5542,7 +5542,7 @@ def test_cql_batcher( torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum( [item for name, item in loss.items() if 
name.startswith("loss_")] @@ -5848,7 +5848,7 @@ def test_dcql_state_dict(self, delay_value, device, action_spec_type): loss_fn2 = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) loss_fn2.load_state_dict(sd) - @pytest.mark.parametrize("n", range(4)) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("delay_value", (False, True)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) @@ -5879,7 +5879,7 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) with torch.no_grad(): loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*td.keys(True, True))) _loss = sum([item for key, item in loss.items() if key.startswith("loss_")]) _loss_ms = sum( @@ -9364,7 +9364,7 @@ def test_iql_separate_losses(self, separate_losses): raise NotImplementedError(k) loss_fn.zero_grad() - @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0]) @pytest.mark.parametrize("expectile", [0.1, 0.5, 1.0]) @@ -9415,7 +9415,7 @@ def test_iql_batcher( torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum( [item for name, item in loss.items() if name.startswith("loss_")] @@ -10176,7 +10176,7 @@ def test_discrete_iql_separate_losses(self, separate_losses): raise NotImplementedError(k) loss_fn.zero_grad() - @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("n", range(1, 4)) @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0]) @pytest.mark.parametrize("expectile", [0.1, 0.5]) @@ -10227,7 +10227,7 @@ def test_discrete_iql_batcher( torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) - if n == 0: + if n == 1: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum( [item for name, item in loss.items() if name.startswith("loss_")] diff --git a/test/test_postprocs.py b/test/test_postprocs.py index c3cba371167..09a1edb518d 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -13,7 +13,7 @@ from torchrl.data.postprocs.postprocs import MultiStep -@pytest.mark.parametrize("n", range(13)) +@pytest.mark.parametrize("n", range(1, 14)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("key", ["observation", "pixels", "observation_whatever"]) def test_multistep(n, key, device, T=11): @@ -58,7 +58,7 @@ def test_multistep(n, key, device, T=11): assert ms_tensordict.get("done").max() == 1 - if n == 0: + if n == 1: assert_allclose_td( tensordict, ms_tensordict.select(*list(tensordict.keys(True, True))) ) @@ -76,12 +76,10 @@ def test_multistep(n, key, device, T=11): ) # check that next obs is properly replaced, or that it is terminated - next_obs = ms_tensordict.get(key)[:, (1 + ms.n_steps) :] - true_next_obs = ms_tensordict.get(("next", key))[:, : -(1 + ms.n_steps)] + next_obs = ms_tensordict.get(key)[:, (ms.n_steps) :] + true_next_obs = ms_tensordict.get(("next", key))[:, : -(ms.n_steps)] terminated = ~ms_tensordict.get("nonterminal") - assert ( - (next_obs == true_next_obs).all(-1) | terminated[:, (1 + ms.n_steps) :] - ).all() + assert ((next_obs == true_next_obs).all(-1) | 
terminated[:, (ms.n_steps) :]).all() # test gamma computation torch.testing.assert_close( @@ -89,7 +87,7 @@ def test_multistep(n, key, device, T=11): ) # test reward - if n > 0: + if n > 1: assert ( ms_tensordict.get(("next", "reward")) != ms_tensordict.get(("next", "original_reward")) @@ -105,36 +103,17 @@ def test_multistep(n, key, device, T=11): @pytest.mark.parametrize( "batch_size", [ - [ - 4, - ], + [4], [], - [ - 1, - ], + [1], [2, 3], ], ) -@pytest.mark.parametrize( - "T", - [ - 10, - 1, - 2, - ], -) -@pytest.mark.parametrize( - "obs_dim", - [ - [ - 1, - ], - [], - ], -) +@pytest.mark.parametrize("T", [10, 1, 2]) +@pytest.mark.parametrize("obs_dim", [[1], []]) @pytest.mark.parametrize("unsq_reward", [True, False]) @pytest.mark.parametrize("last_done", [True, False]) -@pytest.mark.parametrize("n_steps", [3, 1, 0]) +@pytest.mark.parametrize("n_steps", [4, 2, 1]) def test_mutistep_cattrajs( batch_size, T, obs_dim, unsq_reward, last_done, device, n_steps ): @@ -166,7 +145,7 @@ def test_mutistep_cattrajs( ) ms = MultiStep(0.98, n_steps) tdm = ms(td) - if n_steps == 0: + if n_steps == 1: # n_steps = 0 has no effect for k in td["next"].keys(): assert (tdm["next", k] == td["next", k]).all() @@ -179,7 +158,7 @@ def test_mutistep_cattrajs( if unsq_reward: done = done.squeeze(-1) for t in range(T): - idx = t + n_steps + idx = t + n_steps - 1 while (done[..., t:idx].any() and idx > t) or idx > done.shape[-1] - 1: idx = idx - 1 next_obs.append(obs[..., idx]) diff --git a/test/test_transforms.py b/test/test_transforms.py index 27f696a1dfc..24f8e06afd5 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -32,6 +32,7 @@ from mocking_classes import ( ContinuousActionVecMockEnv, CountingBatchedEnv, + CountingEnv, CountingEnvCountPolicy, DiscreteActionConvMockEnv, DiscreteActionConvMockEnvNumpy, @@ -80,6 +81,7 @@ GrayScale, gSDENoise, InitTracker, + MultiStepTransform, NoopResetEnv, ObservationNorm, ParallelEnv, @@ -10191,6 +10193,97 @@ def test_transform_inverse(self): assert ("state", "sub") in td2.keys(True) +class TestMultiStepTransform: + def test_multistep_transform(self): + env = TransformedEnv( + SerialEnv( + 2, [lambda: CountingEnv(max_steps=4), lambda: CountingEnv(max_steps=10)] + ), + StepCounter(), + ) + + env.set_seed(0) + torch.manual_seed(0) + + t = MultiStepTransform(3, 0.98) + + outs_2 = [] + td = env.reset() + for _ in range(1): + rollout = env.rollout( + 250, auto_reset=False, tensordict=td, break_when_any_done=False + ) + out = t._inv_call(rollout) + td = rollout[..., -1] + outs_2.append(out) + # This will break if we don't have the appropriate number of frames + outs_2 = torch.cat(outs_2, -1).split([47, 50, 50, 50, 50], -1) + + t = MultiStepTransform(3, 0.98) + + env.set_seed(0) + torch.manual_seed(0) + + outs = [] + td = env.reset() + for i in range(5): + rollout = env.rollout( + 50, auto_reset=False, tensordict=td, break_when_any_done=False + ) + out = t._inv_call(rollout) + # tests that the data is insensitive to the collection schedule + assert_allclose_td(out, outs_2[i]) + td = rollout[..., -1]["next"] + outs.append(out) + + outs = torch.cat(outs, -1) + + # Test with a very tiny window and across the whole collection + t = MultiStepTransform(3, 0.98) + + env.set_seed(0) + torch.manual_seed(0) + + outs_3 = [] + td = env.reset() + for _ in range(125): + rollout = env.rollout( + 2, auto_reset=False, tensordict=td, break_when_any_done=False + ) + out = t._inv_call(rollout) + td = rollout[..., -1]["next"] + if out is not None: + outs_3.append(out) + + outs_3 = 
torch.cat(outs_3, -1)
+
+ assert_allclose_td(outs, outs_3)
+
+ def test_multistep_transform_changes(self):
+ data = TensorDict(
+ {
+ "steps": torch.arange(100),
+ "next": {
+ "steps": torch.arange(1, 101),
+ "reward": torch.ones(100, 1),
+ "done": torch.zeros(100, 1, dtype=torch.bool),
+ "terminated": torch.zeros(100, 1, dtype=torch.bool),
+ "truncated": torch.zeros(100, 1, dtype=torch.bool),
+ },
+ },
+ batch_size=[100],
+ )
+ data_splits = data.split(10)
+ t = MultiStepTransform(3, 0.98)
+ rb = ReplayBuffer(storage=LazyTensorStorage(100), transform=t)
+ for data in data_splits:
+ rb.extend(data)
+ t.n_steps = t.n_steps + 1
+ assert (rb[:]["steps"] == torch.arange(len(rb))).all()
+ assert rb[:]["next", "steps"][-1] == data["steps"][-1]
+ assert t._buffer["steps"][-1] == data["steps"][-1]
+
+
 if __name__ == "__main__":
 args, unknown = argparse.ArgumentParser().parse_known_args()
 pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py
index d7b2db3f15a..4d15ba9a78d 100644
--- a/torchrl/data/postprocs/postprocs.py
+++ b/torchrl/data/postprocs/postprocs.py
@@ -93,6 +93,48 @@ class MultiStep(nn.Module):
 gamma (float): Discount factor for return computation
 n_steps (integer): maximum look-ahead steps.
+ .. note:: This class is meant to be used within a ``DataCollector``.
+ It only processes the data passed to it at the end of a collection,
+ and ignores data preceding that collection or coming in the next batch.
+ As such, results on the last steps of the batch are likely to be biased
+ by the early truncation of the trajectory.
+ To mitigate this effect, please use :class:`~torchrl.envs.transforms.MultiStepTransform`
+ within the replay buffer instead.
+
+ Examples:
+ >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
+ >>> from torchrl.data.postprocs import MultiStep
+ >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter
+ >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
+ >>> env.set_seed(0)
+ >>> collector = SyncDataCollector(env, policy=RandomPolicy(env.action_spec),
+ ... frames_per_batch=10, total_frames=2000, postproc=MultiStep(n_steps=4, gamma=0.99))
+ >>> for data in collector:
+ ... break
+ >>> print(data["step_count"])
+ tensor([[0],
+ [1],
+ [2],
+ [3],
+ [4],
+ [5],
+ [6],
+ [7],
+ [8],
+ [9]])
+ >>> # the next step count is shifted by 3 steps in the future
+ >>> print(data["next", "step_count"])
+ tensor([[ 5],
+ [ 6],
+ [ 7],
+ [ 8],
+ [ 9],
+ [10],
+ [10],
+ [10],
+ [10],
+ [10]])
+
 """
 def __init__(
@@ -101,7 +143,7 @@ def __init__(
 n_steps: int,
 ):
 super().__init__()
- if n_steps < 0:
+ if n_steps <= 0:
 raise ValueError("n_steps must be a strictly positive integer.")
 if not (gamma > 0 and gamma <= 1):
 raise ValueError(f"got out-of-bounds gamma decay: gamma={gamma}")
@@ -115,6 +157,10 @@ def __init__(
 dtype=torch.float,
 ).reshape(1, 1, -1),
 )
+ self.done_key = "done"
+ self.done_keys = ("done", "terminated", "truncated")
+ self.reward_keys = ("reward",)
+ self.mask_key = ("collector", "mask")
 def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 """Re-writes a tensordict following the multi-step transform.
@@ -151,68 +197,99 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 in-place transformation of the input tensordict. 
""" - tensordict = tensordict.clone(False) - done = tensordict.get(("next", "done")) - truncated = tensordict.get( - ("next", "truncated"), torch.zeros((), dtype=done.dtype, device=done.device) + return _multi_step_func( + tensordict, + done_key=self.done_key, + done_keys=self.done_keys, + reward_keys=self.reward_keys, + mask_key=self.mask_key, + n_steps=self.n_steps, + gamma=self.gamma, ) - done = done | truncated - - # we'll be using the done states to index the tensordict. - # if the shapes don't match we're in trouble. - ndim = tensordict.ndim - if done.shape != tensordict.shape: - if done.shape[-1] == 1 and done.shape[:-1] == tensordict.shape: - done = done.squeeze(-1) - else: - try: - # let's try to reshape the tensordict - tensordict.batch_size = done.shape - tensordict = tensordict.apply( - lambda x: x.transpose(ndim - 1, tensordict.ndim - 1), - batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape, - ) - done = tensordict.get(("next", "done")) - except Exception as err: - raise RuntimeError( - "tensordict shape must be compatible with the done's shape " - "(trailing singleton dimension excluded)." - ) from err - - mask = tensordict.get(("collector", "mask"), None) - reward = tensordict.get(("next", "reward")) - *batch, T = tensordict.batch_size + + +def _multi_step_func( + tensordict, + *, + done_key, + done_keys, + reward_keys, + mask_key, + n_steps, + gamma, +): + # in accordance with common understanding of what n_steps should be + n_steps = n_steps - 1 + tensordict = tensordict.clone(False) + done = tensordict.get(("next", done_key)) + + # we'll be using the done states to index the tensordict. + # if the shapes don't match we're in trouble. + ndim = tensordict.ndim + if done.shape != tensordict.shape: + if done.shape[-1] == 1 and done.shape[:-1] == tensordict.shape: + done = done.squeeze(-1) + else: + try: + # let's try to reshape the tensordict + tensordict.batch_size = done.shape + tensordict = tensordict.apply( + lambda x: x.transpose(ndim - 1, tensordict.ndim - 1), + batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape, + ) + done = tensordict.get(("next", done_key)) + except Exception as err: + raise RuntimeError( + "tensordict shape must be compatible with the done's shape " + "(trailing singleton dimension excluded)." + ) from err + + if mask_key is not None: + mask = tensordict.get(mask_key, None) + else: + mask = None + + *batch, T = tensordict.batch_size + + summed_rewards = [] + for reward_key in reward_keys: + reward = tensordict.get(("next", reward_key)) # sum rewards - summed_rewards, time_to_obs = _get_reward( - self.gamma, reward, done, self.n_steps - ) - idx_to_gather = torch.arange( - T, device=time_to_obs.device, dtype=time_to_obs.dtype - ).expand(*batch, T) - idx_to_gather = idx_to_gather + time_to_obs - # idx_to_gather looks like tensor([[ 2, 3, 4, 5, 5, 5, 8, 9, 10, 10, 10]]) - # with a done state tensor([[ 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]]) - # meaning that the first obs will be replaced by the third, the second by the fourth etc. 
- # The fifth remains the fifth as it is terminal - tensordict_gather = ( - tensordict["next"].exclude("reward", "done").gather(-1, idx_to_gather) - ) + summed_reward, time_to_obs = _get_reward(gamma, reward, done, n_steps) + summed_rewards.append(summed_reward) + + idx_to_gather = torch.arange( + T, device=time_to_obs.device, dtype=time_to_obs.dtype + ).expand(*batch, T) + idx_to_gather = idx_to_gather + time_to_obs - tensordict.set("steps_to_next_obs", time_to_obs + 1) - tensordict.rename_key_(("next", "reward"), ("next", "original_reward")) - tensordict.get("next").update(tensordict_gather) - tensordict.set(("next", "reward"), summed_rewards) - tensordict.set("gamma", self.gamma ** (time_to_obs + 1)) - nonterminal = time_to_obs != 0 - if mask is not None: - mask = mask.view(*batch, T) - nonterminal[~mask] = False - tensordict.set("nonterminal", nonterminal) - if tensordict.ndim != ndim: - tensordict = tensordict.apply( - lambda x: x.transpose(ndim - 1, tensordict.ndim - 1), - batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape, - ) - tensordict.batch_size = tensordict.batch_size[:ndim] - return tensordict + # idx_to_gather looks like tensor([[ 2, 3, 4, 5, 5, 5, 8, 9, 10, 10, 10]]) + # with a done state tensor([[ 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]]) + # meaning that the first obs will be replaced by the third, the second by the fourth etc. + # The fifth remains the fifth as it is terminal + tensordict_gather = ( + tensordict.get("next") + .exclude(*reward_keys, *done_keys) + .gather(-1, idx_to_gather) + ) + + tensordict.set("steps_to_next_obs", time_to_obs + 1) + for reward_key, summed_reward in zip(reward_keys, summed_rewards): + tensordict.rename_key_(("next", reward_key), ("next", "original_reward")) + tensordict.set(("next", reward_key), summed_reward) + + tensordict.get("next").update(tensordict_gather) + tensordict.set("gamma", gamma ** (time_to_obs + 1)) + nonterminal = time_to_obs != 0 + if mask is not None: + mask = mask.view(*batch, T) + nonterminal[~mask] = False + tensordict.set("nonterminal", nonterminal) + if tensordict.ndim != ndim: + tensordict = tensordict.apply( + lambda x: x.transpose(ndim - 1, tensordict.ndim - 1), + batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape, + ) + tensordict.batch_size = tensordict.batch_size[:ndim] + return tensordict diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index f7aaf1329f1..458f491dcb9 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -475,6 +475,8 @@ def add(self, data: Any) -> int: if self._transform is not None and len(self._transform): with _set_dispatch_td_nn_modules(is_tensor_collection(data)): data = self._transform.inv(data) + if data is None: + return torch.zeros((0, self._storage.ndim), dtype=torch.long) return self._add(data) def _add(self, data): @@ -517,6 +519,8 @@ def extend(self, data: Sequence) -> torch.Tensor: if self._transform is not None and len(self._transform): with _set_dispatch_td_nn_modules(is_tensor_collection(data)): data = self._transform.inv(data) + if data is None: + return torch.zeros((0, self._storage.ndim), dtype=torch.long) return self._extend(data) def update_priority( @@ -1008,6 +1012,8 @@ def add(self, data: TensorDictBase) -> int: if self._transform is not None: with _set_dispatch_td_nn_modules(is_tensor_collection(data)): data = self._transform.inv(data) + if data is None: + return torch.zeros((0, self._storage.ndim), dtype=torch.long) index = 
super()._add(data)
 if index is not None:
@@ -1026,6 +1032,8 @@ def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
 )
 if self._transform is not None:
 tensordicts = self._transform.inv(tensordicts)
+ if tensordicts is None:
+ return torch.zeros((0, self._storage.ndim), dtype=torch.long)
 index = super()._extend(tensordicts)
 self._set_index_in_td(tensordicts, index)
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index 86ac25c633c..729171183b1 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -57,6 +57,7 @@
 gSDENoise,
 InitTracker,
 KLRewardTransform,
+ MultiStepTransform,
 NoopResetEnv,
 ObservationNorm,
 ObservationTransform,
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 35769a69b8d..4056c32692a 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -5,6 +5,7 @@
 from .gym_transforms import EndOfLifeTransform
 from .r3m import R3MTransform
+from .rb_transforms import MultiStepTransform
 from .rlhf import KLRewardTransform
 from .transforms import (
 ActionMask,
diff --git a/torchrl/envs/transforms/rb_transforms.py b/torchrl/envs/transforms/rb_transforms.py
new file mode 100644
index 00000000000..2a3cd8ec319
--- /dev/null
+++ b/torchrl/envs/transforms/rb_transforms.py
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from typing import List
+
+import torch
+
+from tensordict import NestedKey, TensorDictBase
+from torchrl.data.postprocs.postprocs import _multi_step_func
+from torchrl.envs.transforms.transforms import Transform
+
+
+class MultiStepTransform(Transform):
+ """A MultiStep transformation for ReplayBuffers.
+
+ This transform keeps the last ``n_steps`` frames of the incoming data in a local buffer.
+ The inverse transform (called during :meth:`~torchrl.data.ReplayBuffer.extend`)
+ outputs the ``n_steps`` transformed frames held over from the previous batch
+ together with the first ``T - n_steps`` frames of the current batch.
+
+ All entries in the ``"next"`` tensordict that are not part of the ``done_keys``
+ or ``reward_keys`` will be mapped to their corresponding entry at step
+ ``t + n_steps - 1``.
+
+ This transform is a more hyperparameter-resistant version of
+ :class:`~torchrl.data.postprocs.postprocs.MultiStep`:
+ the replay buffer transform makes the multi-step computation insensitive
+ to the collector's hyperparameters, whereas the post-processing
+ version produces results that depend on them
+ (because collectors have no memory of previous outputs).
+
+ Args:
+ n_steps (int): Number of steps in multi-step. The number of steps can be
+ dynamically changed by changing the ``n_steps`` attribute of this
+ transform.
+ gamma (float): Discount factor.
+
+ Keyword Args:
+ reward_keys (list of NestedKey, optional): the reward keys in the input tensordict.
+ The reward entries indicated by these keys will be accumulated and discounted
+ across ``n_steps`` steps in the future. A corresponding ``"original_reward"``
+ entry will be written in the ``"next"`` entry of the output tensordict
+ to keep track of the original value of the reward.
+ Defaults to ``["reward"]``.
+ done_key (NestedKey, optional): the done key in the input tensordict, used to indicate
+ the end of a trajectory.
+ Defaults to ``"done"``.
+ done_keys (list of NestedKey, optional): the list of end-of-trajectory keys in the input tensordict. 
+ All the entries indicated by these keys will be left untouched by the transform. + Defaults to ``["done", "truncated", "terminated"]``. + mask_key (NestedKey, optional): the mask key in the input tensordict. + The mask represents the valid frames in the input tensordict and + should have a shape that allows the input tensordict to be masked + with. + Defaults to ``"mask"``. + + Examples: + >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter, MultiStepTransform, SerialEnv + >>> from torchrl.data import ReplayBuffer, LazyTensorStorage + >>> rb = ReplayBuffer( + ... storage=LazyTensorStorage(100, ndim=2), + ... transform=MultiStepTransform(n_steps=3, gamma=0.95) + ... ) + >>> base_env = SerialEnv(2, lambda: GymEnv("CartPole")) + >>> env = TransformedEnv(base_env, StepCounter()) + >>> _ = env.set_seed(0) + >>> _ = torch.manual_seed(0) + >>> tdreset = env.reset() + >>> for _ in range(100): + ... rollout = env.rollout(max_steps=50, break_when_any_done=False, + ... tensordict=tdreset, auto_reset=False) + ... indices = rb.extend(rollout) + ... tdreset = rollout[..., -1]["next"] + >>> print("step_count", rb[:]["step_count"][:, :5]) + step_count tensor([[[ 9], + [10], + [11], + [12], + [13]], + + [[12], + [13], + [14], + [15], + [16]]]) + >>> # The next step_count is 3 steps in the future + >>> print("next step_count", rb[:]["next", "step_count"][:, :5]) + next step_count tensor([[[13], + [14], + [15], + [16], + [17]], + + [[16], + [17], + [18], + [19], + [20]]]) + + """ + + ENV_ERR = ( + "The MultiStepTransform is only an inverse transform and can " + "be applied exclusively to replay buffers." + ) + + def __init__( + self, + n_steps, + gamma, + *, + reward_keys: List[NestedKey] | None = None, + done_key: NestedKey | None = None, + done_keys: List[NestedKey] | None = None, + mask_key: NestedKey | None = None, + ): + super().__init__() + self.n_steps = n_steps + self.reward_keys = reward_keys + self.done_key = done_key + self.done_keys = done_keys + self.mask_key = mask_key + self.gamma = gamma + self._buffer = None + self._validated = False + + @property + def n_steps(self): + """The look ahead window of the transform. + + This value can be dynamically edited during training. + """ + return self._n_steps + + @n_steps.setter + def n_steps(self, value): + if not isinstance(value, int) or not (value >= 1): + raise ValueError( + "The value of n_steps must be a strictly positive integer." 
+ ) + self._n_steps = value + + @property + def done_key(self): + return self._done_key + + @done_key.setter + def done_key(self, value): + if value is None: + value = "done" + self._done_key = value + + @property + def done_keys(self): + return self._done_keys + + @done_keys.setter + def done_keys(self, value): + if value is None: + value = ["done", "terminated", "truncated"] + self._done_keys = value + + @property + def reward_keys(self): + return self._reward_keys + + @reward_keys.setter + def reward_keys(self, value): + if value is None: + value = [ + "reward", + ] + self._reward_keys = value + + @property + def mask_key(self): + return self._mask_key + + @mask_key.setter + def mask_key(self, value): + if value is None: + value = "mask" + self._mask_key = value + + def _validate(self): + if self.parent is not None: + raise ValueError(self.ENV_ERR) + self._validated = True + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + if not self._validated: + self._validate() + + total_cat = self._append_tensordict(tensordict) + if total_cat.shape[-1] >= self.n_steps: + out = _multi_step_func( + total_cat, + done_key=self.done_key, + done_keys=self.done_keys, + reward_keys=self.reward_keys, + mask_key=self.mask_key, + n_steps=self.n_steps, + gamma=self.gamma, + ) + return out[..., : -self.n_steps] + + def _append_tensordict(self, data): + if self._buffer is None: + total_cat = data + self._buffer = data[..., -self.n_steps :].copy() + else: + total_cat = torch.cat([self._buffer, data], -1) + self._buffer = total_cat[..., -self.n_steps :].copy() + return total_cat diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7d3a7cb0ab9..124250157ec 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6255,7 +6255,7 @@ class Reward2GoTransform(Transform): As the :class:`~.Reward2GoTransform` is only an inverse transform the ``in_keys`` will be directly used for the ``in_keys_inv``. The reward-to-go can be only calculated once the episode is finished. Therefore, the transform should be applied to the replay buffer - and not to the collector. + and not to the collector or within an environment. Args: gamma (float or torch.Tensor): the discount factor. Defaults to 1.0. @@ -6354,7 +6354,7 @@ class Reward2GoTransform(Transform): ENV_ERR = ( "The Reward2GoTransform is only an inverse transform and can " - "only be applied to the replay buffer and not to the collector or the environment." + "only be applied to the replay buffer." ) def __init__(
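
The following standalone sketch illustrates the n-step bookkeeping that ``_multi_step_func`` and ``MultiStepTransform`` perform: for every step ``t``, rewards are accumulated and discounted over at most ``n_steps`` future steps, stopping early at a done flag, and the ``"next"`` entries are taken from the step where the accumulation stopped. It is a minimal illustration only, assuming plain 1D tensors rather than TensorDicts and ignoring batching, masks and the ``gamma`` entry; the helper name ``naive_multistep`` is made up for this example and is not part of the patch (the actual vectorized logic lives in ``_get_reward`` and ``_multi_step_func``).

```python
import torch


def naive_multistep(reward, done, gamma, n_steps):
    """Toy n-step return on 1D tensors (illustrative only).

    For each step ``t``, sums ``gamma**k * reward[t + k]`` for ``k`` in
    ``[0, n_steps)``, stopping early when a done flag is met. Returns the
    summed rewards and, for each step, how many steps ahead the "next"
    entries would point (``steps_to_next_obs``).
    """
    T = reward.shape[0]
    summed = torch.zeros_like(reward)
    steps_to_next_obs = torch.zeros(T, dtype=torch.long)
    for t in range(T):
        acc, k = 0.0, 0
        while k < n_steps and t + k < T:
            acc += (gamma ** k) * reward[t + k].item()
            k += 1
            if done[t + k - 1]:
                # the trajectory ended: do not accumulate past this frame
                break
        summed[t] = acc
        # the "next" entries of step t would be gathered from step t + k - 1
        steps_to_next_obs[t] = k
    return summed, steps_to_next_obs


# constant reward of 1, trajectory terminating at step 4
reward = torch.ones(6)
done = torch.tensor([False, False, False, False, True, False])
print(naive_multistep(reward, done, gamma=0.9, n_steps=3))
```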