diff --git a/test/test_env.py b/test/test_env.py index 36671db8922..5fd2f5c60c3 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1727,7 +1727,7 @@ def test_multi_purpose_env(self, serial): env = SerialEnv(2, ContinuousActionVecMockEnv) else: env = ContinuousActionVecMockEnv() - rollout = env.rollout(10) + env.rollout(10) assert env._step_mdp.validate(None) c = SyncDataCollector( env, env.rand_action, frames_per_batch=10, total_frames=20 @@ -1736,7 +1736,18 @@ def test_multi_purpose_env(self, serial): pass assert ("collector", "traj_ids") in data.keys(True) assert env._step_mdp.validate(None) - rollout = env.rollout(10) + env.rollout(10) + + # An exception will be raised when the collector sees extra keys + if serial: + env = SerialEnv(2, ContinuousActionVecMockEnv) + else: + env = ContinuousActionVecMockEnv() + c = SyncDataCollector( + env, env.rand_action, frames_per_batch=10, total_frames=20 + ) + for data in c: # noqa: B007 + pass @pytest.mark.parametrize("device", get_default_devices()) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 12d282cc86d..6094ccf9f77 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -167,7 +167,7 @@ def validate(self, tensordict): ) actual = set(tensordict.keys(True, True)) expected = set(expected) - self.validated = expected.union(actual) == expected + self.validated = expected.intersection(actual) == expected if not self.validated: warnings.warn( "The expected key set and actual key set differ. "