Skip to content

Commit

Permalink
[Feature] Auto-resetting envs (#2073)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 16, 2024
1 parent d2cfd28 commit 8570bd3
Show file tree
Hide file tree
Showing 9 changed files with 743 additions and 43 deletions.
78 changes: 78 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,82 @@ single agent standards.
MarlGroupMapType
check_marl_grouping

Auto-resetting Envs
-------------------

Auto-resetting environments are environments where calls to :meth:`~torchrl.envs.EnvBase.reset` are not expected when
the environment reaches a ``"done"`` state during a rollout, as the reset happens automatically.
Usually, in such cases the observations delivered with the done and reward (which effectively result from performing the
action in the environment) are actually the first observations of a new episode, and not the last observations of the
current episode.

To handle these cases, torchrl provides a :class:`~torchrl.envs.AutoResetTransform` that will copy the observations
that result from the call to `step` to the next `reset` and skip the calls to `reset` during rollouts (in both
:meth:`~torchrl.envs.EnvBase.rollout` and :class:`~torchrl.collectors.SyncDataCollector` iterations).
This transform class also provides a fine-grained control over the behaviour to be adopted for the invalid observations,
which can be masked with `"nan"` or any other values, or not masked at all.

To tell torchrl that an environment is auto-resetting, it is sufficient to provide an ``auto_reset`` argument
during construction. If provided, an ``auto_reset_replace`` argument can also control whether the values of the last
observation of an episode should be replaced with some placeholder or not.

>>> from torchrl.envs import GymEnv
>>> from torchrl.envs import set_gym_backend
>>> import torch
>>> torch.manual_seed(0)
>>>
>>> class AutoResettingGymEnv(GymEnv):
... def _step(self, tensordict):
... tensordict = super()._step(tensordict)
... if tensordict["done"].any():
... td_reset = super().reset()
... tensordict.update(td_reset.exclude(*self.done_keys))
... return tensordict
...
... def _reset(self, tensordict=None):
... if tensordict is not None and "_reset" in tensordict:
... return tensordict.copy()
... return super()._reset(tensordict)
>>>
>>> with set_gym_backend("gym"):
... env = AutoResettingGymEnv("CartPole-v1", auto_reset=True, auto_reset_replace=True)
... env.set_seed(0)
... r = env.rollout(30, break_when_any_done=False)
>>> print(r["next", "done"].squeeze())
tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, True, False, False, False, False, False, False,
False, False, False, False, False, True, False, False, False, False])
>>> print("observation after reset are set as nan", r["next", "observation"])
observation after reset are set as nan tensor([[-4.3633e-02, -1.4877e-01, 1.2849e-02, 2.7584e-01],
[-4.6609e-02, 4.6166e-02, 1.8366e-02, -1.2761e-02],
[-4.5685e-02, 2.4102e-01, 1.8111e-02, -2.9959e-01],
[-4.0865e-02, 4.5644e-02, 1.2119e-02, -1.2542e-03],
[-3.9952e-02, 2.4059e-01, 1.2094e-02, -2.9009e-01],
[-3.5140e-02, 4.3554e-01, 6.2920e-03, -5.7893e-01],
[-2.6429e-02, 6.3057e-01, -5.2867e-03, -8.6963e-01],
[-1.3818e-02, 8.2576e-01, -2.2679e-02, -1.1640e+00],
[ 2.6972e-03, 1.0212e+00, -4.5959e-02, -1.4637e+00],
[ 2.3121e-02, 1.2168e+00, -7.5232e-02, -1.7704e+00],
[ 4.7457e-02, 1.4127e+00, -1.1064e-01, -2.0854e+00],
[ 7.5712e-02, 1.2189e+00, -1.5235e-01, -1.8289e+00],
[ 1.0009e-01, 1.0257e+00, -1.8893e-01, -1.5872e+00],
[ nan, nan, nan, nan],
[-3.9405e-02, -1.7766e-01, -1.0403e-02, 3.0626e-01],
[-4.2959e-02, -3.7263e-01, -4.2775e-03, 5.9564e-01],
[-5.0411e-02, -5.6769e-01, 7.6354e-03, 8.8698e-01],
[-6.1765e-02, -7.6292e-01, 2.5375e-02, 1.1820e+00],
[-7.7023e-02, -9.5836e-01, 4.9016e-02, 1.4826e+00],
[-9.6191e-02, -7.6387e-01, 7.8667e-02, 1.2056e+00],
[-1.1147e-01, -9.5991e-01, 1.0278e-01, 1.5219e+00],
[-1.3067e-01, -7.6617e-01, 1.3322e-01, 1.2629e+00],
[-1.4599e-01, -5.7298e-01, 1.5848e-01, 1.0148e+00],
[-1.5745e-01, -7.6982e-01, 1.7877e-01, 1.3527e+00],
[-1.7285e-01, -9.6668e-01, 2.0583e-01, 1.6956e+00],
[ nan, nan, nan, nan],
[-4.3962e-02, 1.9845e-01, -4.5015e-02, -2.5903e-01],
[-3.9993e-02, 3.9418e-01, -5.0196e-02, -5.6557e-01],
[-3.2109e-02, 5.8997e-01, -6.1507e-02, -8.7363e-01],
[-2.0310e-02, 3.9574e-01, -7.8980e-02, -6.0090e-01]])


Transforms
Expand Down Expand Up @@ -580,6 +656,8 @@ to be able to create this other composition:
Transform
TransformedEnv
ActionMask
AutoResetEnv
AutoResetTransform
BatchSizeTransform
BinarizeReward
BurnInTransform
Expand Down
84 changes: 77 additions & 7 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torchrl.data.utils import consolidate_spec
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based.common import ModelBasedEnvBase
from torchrl.envs.utils import _terminated_or_truncated

spec_dict = {
"bounded": BoundedTensorSpec,
Expand Down Expand Up @@ -1407,7 +1408,9 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
count[:] = self.start_val

self.register_buffer("count", count)
self._make_specs()

def _make_specs(self):
obs_specs = []
action_specs = []
for index in range(self.n_nested_dim):
Expand All @@ -1419,13 +1422,7 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):

self.unbatched_observation_spec = CompositeSpec(
lazy=obs_spec_unlazy,
state=UnboundedContinuousTensorSpec(
shape=(
64,
64,
3,
)
),
state=UnboundedContinuousTensorSpec(shape=(64, 64, 3)),
device=self.device,
)

Expand Down Expand Up @@ -1828,3 +1825,76 @@ def _step(

def _set_seed(self, seed: Optional[int]):
torch.manual_seed(seed)


class AutoResettingCountingEnv(CountingEnv):
def _step(self, tensordict):
tensordict = super()._step(tensordict)
if tensordict["done"].any():
td_reset = super().reset()
tensordict.update(td_reset.exclude(*self.done_keys))
return tensordict

def _reset(self, tensordict=None):
if tensordict is not None and "_reset" in tensordict:
raise RuntimeError
return super()._reset(tensordict)


class AutoResetHeteroCountingEnv(HeterogeneousCountingEnv):
def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
super().__init__(**kwargs)
self.n_nested_dim = 3
self.max_steps = max_steps
self.start_val = start_val

count = torch.zeros(
(*self.batch_size, self.n_nested_dim, 1),
device=self.device,
dtype=torch.int,
)
count[:] = self.start_val

self.register_buffer("count", count)
self._make_specs()

def _step(self, tensordict):
for i in range(self.n_nested_dim):
action = tensordict["lazy"][..., i]["action"]
action = action[..., 0].to(torch.bool)
self.count[..., i, 0] += action

td = self.observation_spec.zero()
for done_key in self.done_keys:
td[done_key] = self.count > self.max_steps

any_done = _terminated_or_truncated(
td,
full_done_spec=self.output_spec["full_done_spec"],
key=None,
)
if any_done:
self.count[td["lazy", "done"]] = 0

for i in range(self.n_nested_dim):
lazy = tensordict["lazy"][..., i]
for obskey in self.observation_spec.keys(True, True):
if isinstance(obskey, tuple) and obskey[0] == "lazy":
lazy[obskey[1:]] += expand_right(
self.count[..., i, 0], lazy[obskey[1:]].shape
).clone()
td.update(self.output_spec["full_done_spec"].zero())
td.update(self.output_spec["full_reward_spec"].zero())

assert td.batch_size == self.batch_size
return td

def _reset(self, tensordict=None):
if tensordict is not None and self.reset_keys[0] in tensordict.keys(True):
raise RuntimeError
self.count[:] = self.start_val

reset_td = self.observation_spec.zero()
reset_td.update(self.full_done_spec.zero())
assert reset_td.batch_size == self.batch_size
return reset_td
159 changes: 159 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
)
from mocking_classes import (
ActionObsMergeLinear,
AutoResetHeteroCountingEnv,
AutoResettingCountingEnv,
ContinuousActionConvMockEnv,
ContinuousActionConvMockEnvNumpy,
ContinuousActionVecMockEnv,
Expand Down Expand Up @@ -80,6 +82,7 @@
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
from torchrl.envs.transforms.transforms import AutoResetEnv, AutoResetTransform
from torchrl.envs.utils import (
_StepMDP,
_terminated_or_truncated,
Expand Down Expand Up @@ -2868,6 +2871,162 @@ def test_stackable():
assert _stackable(*stack)


class TestAutoReset:
def test_auto_reset(self):
policy = lambda td: td.set(
"action", torch.ones((*td.shape, 1), dtype=torch.int64)
)

env = AutoResettingCountingEnv(4, auto_reset=True)
assert isinstance(env, TransformedEnv) and isinstance(
env.transform, AutoResetTransform
)
r = env.rollout(20, policy, break_when_any_done=False)
assert r.shape == torch.Size([20])
assert r["next", "done"].sum() == 4
assert (r["next", "observation"][r["next", "done"].squeeze()] == -1).all(), r[
"next", "observation"
][r["next", "done"].squeeze()]
assert (
r[..., 1:]["observation"][r[..., :-1]["next", "done"].squeeze()] == 0
).all()
r = env.rollout(20, policy, break_when_any_done=True)
assert r["next", "done"].sum() == 1
assert not r["done"].any()

def test_auto_reset_transform(self):
policy = lambda td: td.set(
"action", torch.ones((*td.shape, 1), dtype=torch.int64)
)
env = TransformedEnv(
AutoResettingCountingEnv(4, auto_reset=True), StepCounter()
)
assert isinstance(env, TransformedEnv) and isinstance(
env.base_env.transform, AutoResetTransform
)
r = env.rollout(20, policy, break_when_any_done=False)
assert r.shape == torch.Size([20])
assert r["next", "done"].sum() == 4
assert (r["next", "observation"][r["next", "done"].squeeze()] == -1).all()
assert (
r[..., 1:]["observation"][r[..., :-1]["next", "done"].squeeze()] == 0
).all()
r = env.rollout(20, policy, break_when_any_done=True)
assert r["next", "done"].sum() == 1
assert not r["done"].any()

def test_auto_reset_serial(self):
policy = lambda td: td.set(
"action", torch.ones((*td.shape, 1), dtype=torch.int64)
)
env = SerialEnv(
2, functools.partial(AutoResettingCountingEnv, 4, auto_reset=True)
)
r = env.rollout(20, policy, break_when_any_done=False)
assert r.shape == torch.Size([2, 20])
assert r["next", "done"].sum() == 8
assert (r["next", "observation"][r["next", "done"].squeeze()] == -1).all()
assert (
r[..., 1:]["observation"][r[..., :-1]["next", "done"].squeeze()] == 0
).all()
r = env.rollout(20, policy, break_when_any_done=True)
assert r["next", "done"].sum() == 2
assert not r["done"].any()

def test_auto_reset_serial_hetero(self):
policy = lambda td: td.set(
"action", torch.ones((*td.shape, 1), dtype=torch.int64)
)
env = SerialEnv(
2,
[
functools.partial(AutoResettingCountingEnv, 4, auto_reset=True),
functools.partial(AutoResettingCountingEnv, 5, auto_reset=True),
],
)
r = env.rollout(20, policy, break_when_any_done=False)
assert r.shape == torch.Size([2, 20])
assert (r["next", "observation"][r["next", "done"].squeeze()] == -1).all()
assert (
r[..., 1:]["observation"][r[..., :-1]["next", "done"].squeeze()] == 0
).all()
assert not r["done"].any()

def test_auto_reset_parallel(self):
policy = lambda td: td.set(
"action", torch.ones((*td.shape, 1), dtype=torch.int64)
)
env = ParallelEnv(
2,
functools.partial(AutoResettingCountingEnv, 4, auto_reset=True),
mp_start_method="fork",
)
r = env.rollout(20, policy, break_when_any_done=False)
assert r.shape == torch.Size([2, 20])
assert r["next", "done"].sum() == 8
assert (r["next", "observation"][r["next", "done"].squeeze()] == -1).all()
assert (
r[..., 1:]["observation"][r[..., :-1]["next", "done"].squeeze()] == 0
).all()
r = env.rollout(20, policy, break_when_any_done=True)
assert r["next", "done"].sum() == 2
assert not r["done"].any()

def test_auto_reset_parallel_hetero(self):
policy = lambda td: td.set(
"action", torch.ones((*td.shape, 1), dtype=torch.int64)
)
env = ParallelEnv(
2,
[
functools.partial(AutoResettingCountingEnv, 4, auto_reset=True),
functools.partial(AutoResettingCountingEnv, 5, auto_reset=True),
],
mp_start_method="fork",
)
r = env.rollout(20, policy, break_when_any_done=False)
assert r.shape == torch.Size([2, 20])
assert (r["next", "observation"][r["next", "done"].squeeze()] == -1).all()
assert (
r[..., 1:]["observation"][r[..., :-1]["next", "done"].squeeze()] == 0
).all()
assert not r["done"].any()

def test_auto_reset_heterogeneous_env(self):
torch.manual_seed(0)
env = TransformedEnv(
AutoResetHeteroCountingEnv(4, auto_reset=True), StepCounter()
)

def policy(td):
return td.update(
env.full_action_spec.zero().apply(lambda x: x.bernoulli_(0.5))
)

assert isinstance(env.base_env, AutoResetEnv) and isinstance(
env.base_env.transform, AutoResetTransform
)
check_env_specs(env)
r = env.rollout(40, policy, break_when_any_done=False)
assert (r["next", "lazy", "step_count"] - 1 == r["lazy", "step_count"]).all()
done = r["next", "lazy", "done"].squeeze(-1)[:-1]
assert (
r["next", "lazy", "step_count"][1:][~done]
== r["next", "lazy", "step_count"][:-1][~done] + 1
).all()
assert (
r["next", "lazy", "step_count"][1:][done]
!= r["next", "lazy", "step_count"][:-1][done] + 1
).all()
done_split = r["next", "lazy", "done"].unbind(1)
lazy_slit = r["next", "lazy"].unbind(1)
lazy_roots = r["lazy"].unbind(1)
for lazy, lazy_root, done in zip(lazy_slit, lazy_roots, done_split):
assert lazy["lidar"][done.squeeze()].isnan().all()
assert not lazy["lidar"][~done.squeeze()].isnan().any()
assert (lazy_root["lidar"][1:][done[:-1].squeeze()] == 0).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
Loading

0 comments on commit 8570bd3

Please sign in to comment.