From e7630f1b5e9e02eff9b150dd7a1242435eb8e623 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 24 Oct 2023 06:33:04 -0400 Subject: [PATCH] [Feature] Exclude all private keys in collectors (#1644) --- test/test_collector.py | 17 ++++++++++++----- torchrl/collectors/collectors.py | 12 +++++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 9009f33b303..8667ea24790 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -34,7 +34,7 @@ MultiKeyCountingEnvPolicy, NestedCountingEnv, ) -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn @@ -939,17 +939,22 @@ def create_env(): [MultiSyncDataCollector, MultiaSyncDataCollector, SyncDataCollector], ) @pytest.mark.parametrize("exclude", [True, False]) -def test_excluded_keys(collector_class, exclude): +@pytest.mark.parametrize("out_key", ["_dummy", ("out", "_dummy"), ("_out", "dummy")]) +def test_excluded_keys(collector_class, exclude, out_key): if not exclude and collector_class is not SyncDataCollector: pytest.skip("defining _exclude_private_keys is not possible") def make_env(): - return ContinuousActionVecMockEnv() + return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) dummy_env = make_env() obs_spec = dummy_env.observation_spec["observation"] policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) - policy = Actor(policy_module, spec=dummy_env.action_spec) + policy = TensorDictModule( + policy_module, in_keys=["observation"], out_keys=["action"] + ) + copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=[out_key]) + policy = TensorDictSequential(policy, copier) policy_explore = OrnsteinUhlenbeckProcessWrapper(policy) collector_kwargs = { @@ -966,11 +971,13 @@ def make_env(): collector = collector_class(**collector_kwargs) collector._exclude_private_keys = exclude for b in collector: - keys = b.keys() + keys = set(b.keys()) if exclude: assert not any(key.startswith("_") for key in keys) + assert out_key not in b.keys(True, True) else: assert any(key.startswith("_") for key in keys) + assert out_key in b.keys(True, True) break collector.shutdown() dummy_env.close() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 2bbf1f927a0..e92172e3437 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -758,8 +758,18 @@ def iterator(self) -> Iterator[TensorDictBase]: if self.postproc is not None: tensordict_out = self.postproc(tensordict_out) if self._exclude_private_keys: + + def is_private(key): + if isinstance(key, str) and key.startswith("_"): + return True + if isinstance(key, tuple) and any( + _key.startswith("_") for _key in key + ): + return True + return False + excluded_keys = [ - key for key in tensordict_out.keys() if key.startswith("_") + key for key in tensordict_out.keys(True) if is_private(key) ] tensordict_out = tensordict_out.exclude( *excluded_keys, inplace=True