diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 7808de80ad2..d885a3ae832 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2680,7 +2680,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """ if self._simple_done: done = tensordict._get_str("done", default=None) - tensordict._set_str("done", done.clone(), validated=True, inplace=False) + tensordict._set_str("_reset", done.clone(), validated=True, inplace=False) any_done = done.any() else: any_done = _terminated_or_truncated( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index fb7e778f957..7d3a7cb0ab9 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -733,7 +733,6 @@ def input_spec(self) -> TensorSpec: def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # No need to clone here because inv does it already # tensordict = tensordict.clone(False) - next_preset = tensordict.get("next", None) tensordict_in = self.transform.inv(tensordict) next_tensordict = self.base_env._step(tensordict_in)