Skip to content

Commit

Permalink
[Feature] Exclude all private keys in collectors (#1644)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 24, 2023
1 parent 2e32c10 commit e7630f1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
17 changes: 12 additions & 5 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
)
from tensordict.nn import TensorDictModule
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.tensordict import assert_allclose_td, TensorDict

from torch import nn
Expand Down Expand Up @@ -939,17 +939,22 @@ def create_env():
[MultiSyncDataCollector, MultiaSyncDataCollector, SyncDataCollector],
)
@pytest.mark.parametrize("exclude", [True, False])
def test_excluded_keys(collector_class, exclude):
@pytest.mark.parametrize("out_key", ["_dummy", ("out", "_dummy"), ("_out", "dummy")])
def test_excluded_keys(collector_class, exclude, out_key):
if not exclude and collector_class is not SyncDataCollector:
pytest.skip("defining _exclude_private_keys is not possible")

def make_env():
return ContinuousActionVecMockEnv()
return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

dummy_env = make_env()
obs_spec = dummy_env.observation_spec["observation"]
policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1])
policy = Actor(policy_module, spec=dummy_env.action_spec)
policy = TensorDictModule(
policy_module, in_keys=["observation"], out_keys=["action"]
)
copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=[out_key])
policy = TensorDictSequential(policy, copier)
policy_explore = OrnsteinUhlenbeckProcessWrapper(policy)

collector_kwargs = {
Expand All @@ -966,11 +971,13 @@ def make_env():
collector = collector_class(**collector_kwargs)
collector._exclude_private_keys = exclude
for b in collector:
keys = b.keys()
keys = set(b.keys())
if exclude:
assert not any(key.startswith("_") for key in keys)
assert out_key not in b.keys(True, True)
else:
assert any(key.startswith("_") for key in keys)
assert out_key in b.keys(True, True)
break
collector.shutdown()
dummy_env.close()
Expand Down
12 changes: 11 additions & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,18 @@ def iterator(self) -> Iterator[TensorDictBase]:
if self.postproc is not None:
tensordict_out = self.postproc(tensordict_out)
if self._exclude_private_keys:

def is_private(key):
if isinstance(key, str) and key.startswith("_"):
return True
if isinstance(key, tuple) and any(
_key.startswith("_") for _key in key
):
return True
return False

excluded_keys = [
key for key in tensordict_out.keys() if key.startswith("_")
key for key in tensordict_out.keys(True) if is_private(key)
]
tensordict_out = tensordict_out.exclude(
*excluded_keys, inplace=True
Expand Down

0 comments on commit e7630f1

Please sign in to comment.