diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py
index d4a67e7d3a9..415e19a1f7c 100644
--- a/sota-implementations/decision_transformer/utils.py
+++ b/sota-implementations/decision_transformer/utils.py
@@ -109,7 +109,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
     )
 
     # copy action from the input tensordict to the output
-    transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec))
+    transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec))
 
     transformed_env.append_transform(DoubleToFloat())
     obsnorm = ObservationNorm(
diff --git a/test/test_transforms.py b/test/test_transforms.py
index cc3ca40b059..44ebce72c5c 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -7408,7 +7408,7 @@ def make_env():
     def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
         env = TransformedEnv(
             maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv),
-            TensorDictPrimer(mykey=Unbounded([2, 4])),
+            TensorDictPrimer(mykey=Unbounded([4])),
         )
         try:
             check_env_specs(env)
@@ -7423,11 +7423,39 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
            pass
 
     @pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
-    def test_trans_serial_env_check(self, spec_shape):
-        env = TransformedEnv(
-            SerialEnv(2, ContinuousActionVecMockEnv),
-            TensorDictPrimer(mykey=Unbounded(spec_shape)),
-        )
+    @pytest.mark.parametrize("expand_specs", [True, False, None])
+    def test_trans_serial_env_check(self, spec_shape, expand_specs):
+        if expand_specs is None:
+            with pytest.warns(FutureWarning, match=""):
+                env = TransformedEnv(
+                    SerialEnv(2, ContinuousActionVecMockEnv),
+                    TensorDictPrimer(
+                        mykey=Unbounded(spec_shape), expand_specs=expand_specs
+                    ),
+                )
+                env.observation_spec
+        elif expand_specs is True:
+            shape = spec_shape[:-1]
+            env = TransformedEnv(
+                SerialEnv(2, ContinuousActionVecMockEnv),
+                TensorDictPrimer(
+                    Composite(mykey=Unbounded(spec_shape), shape=shape),
+                    expand_specs=expand_specs,
+                ),
+            )
+        else:
+            # If we don't expand, we can't use [4]
+            env = TransformedEnv(
+                SerialEnv(2, ContinuousActionVecMockEnv),
+                TensorDictPrimer(
+                    mykey=Unbounded(spec_shape), expand_specs=expand_specs
+                ),
+            )
+            if spec_shape == [4]:
+                with pytest.raises(ValueError):
+                    env.observation_spec
+                return
+
         check_env_specs(env)
         assert "mykey" in env.reset().keys()
         r = env.rollout(3)
@@ -10310,9 +10338,8 @@ def _make_transform_env(self, out_key, base_env):
         transform = KLRewardTransform(actor, out_keys=out_key)
         return Compose(
             TensorDictPrimer(
-                primers={
-                    "sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1])
-                }
+                sample_log_prob=Unbounded(shape=base_env.action_spec.shape[:-1]),
+                shape=base_env.shape,
             ),
             transform,
         )
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 14d4133412c..5b2dafc9b04 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -4984,6 +4984,7 @@ def __init__(
         | Dict[NestedKey, float]
         | Dict[NestedKey, Callable] = None,
         reset_key: NestedKey | None = None,
+        expand_specs: bool = None,
         **kwargs,
     ):
         self.device = kwargs.pop("device", None)
@@ -4995,8 +4996,10 @@ def __init__(
                 )
             kwargs = primers
         if not isinstance(kwargs, Composite):
-            kwargs = Composite(kwargs)
+            kwargs = Composite(**kwargs)
         self.primers = kwargs
+        self.expand_specs = expand_specs
+
         if random and default_value:
             raise ValueError(
                 "Setting random to True and providing a default_value are incompatible."
@@ -5089,15 +5092,25 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             )
         if self.primers.shape != observation_spec.shape:
-            if self.primers.shape == () and self.parent.batch_size != ():
+            if self.expand_specs:
                 self.primers = self._expand_shape(self.primers)
-            else:
+            elif self.expand_specs is None:
+                warnings.warn(
+                    f"expand_specs wasn't specified in the {type(self).__name__} constructor. "
+                    f"The current behaviour is that the transform will attempt to set the shape of the composite "
+                    f"spec, and if this can't be done it will be expanded. "
+                    f"From v0.8, a mismatched shape between the spec of the transform and the env's batch_size "
+                    f"will raise an exception.",
+                    category=FutureWarning,
+                )
                 try:
                     # We try to set the primer shape to the observation spec shape
                     self.primers.shape = observation_spec.shape
                 except ValueError:
                     # If we fail, we expand them to that shape
                     self.primers = self._expand_shape(self.primers)
+            else:
+                self.primers.shape = observation_spec.shape
 
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index 57bcac94cf4..68309c346cd 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -653,7 +653,8 @@ def make_tuple(key):
             {
                 in_key1: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)),
                 in_key2: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)),
-            }
+            },
+            expand_specs=True,
         )
 
     @property
@@ -1467,7 +1468,8 @@ def make_tuple(key):
         return TensorDictPrimer(
             {
                 in_key1: Unbounded(shape=(self.gru.num_layers, self.gru.hidden_size)),
-            }
+            },
+            expand_specs=True,
        )
 
     @property
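
Not part of the patch: a minimal usage sketch of the new expand_specs argument, mirroring the updated serial-env test above. ContinuousActionVecMockEnv is the test-only mock from torchrl's mocking_classes module, so this snippet is an illustration rather than library documentation.

    # Illustration only: assumes this patch is applied and torchrl's test mock env is importable.
    from mocking_classes import ContinuousActionVecMockEnv  # test helper living in test/

    from torchrl.data import Composite, Unbounded
    from torchrl.envs import SerialEnv, TransformedEnv
    from torchrl.envs.transforms import TensorDictPrimer

    env = TransformedEnv(
        SerialEnv(2, ContinuousActionVecMockEnv),
        TensorDictPrimer(
            # With expand_specs=True, the (4,)-shaped primer is expanded to the
            # SerialEnv's batch_size of (2,), giving a (2, 4) "mykey" entry
            # instead of raising or emitting the FutureWarning.
            Composite(mykey=Unbounded([4])),
            expand_specs=True,
        ),
    )
    td = env.reset()          # "mykey" is primed at reset
    rollout = env.rollout(3)  # and carried through the rollout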