diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 093ead56164..a0a7cb23dfd 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -472,6 +472,35 @@ def ndimension(self): def ndim(self): return self.ndimension() + def append_transform( + self, + transform: "Transform" # noqa: F821 + | Callable[[TensorDictBase], TensorDictBase], + ) -> None: + """Returns a transformed environment where the callable/transform passed is applied. + + Args: + transform (Transform or Callable[[TensorDictBase], TensorDictBase]): the transform to apply + to the environment. + + Examples: + >>> from torchrl.envs import GymEnv + >>> import torch + >>> env = GymEnv("CartPole-v1") + >>> loc = 0.5 + >>> scale = 1.0 + >>> transform = lambda data: data.set("observation", (data.get("observation") - loc)/scale) + >>> env = env.append_transform(transform=transform) + >>> print(env) + TransformedEnv( + env=GymEnv(env=CartPole-v1, batch_size=torch.Size([]), device=cpu), + transform=_CallableTransform(keys=[])) + + """ + from torchrl.envs.transforms.transforms import TransformedEnv + + return TransformedEnv(self, transform) + # Parent specs: input and output spec. @property def input_spec(self) -> TensorSpec: