Skip to content

Commit

Permalink
[Feature] flexible batch_locked for jumanji
Browse files Browse the repository at this point in the history
ghstack-source-id: e356b6511ff3da8a6c583747214cfa90f42c9083
Pull Request resolved: #2382
  • Loading branch information
vmoens committed Nov 8, 2024
1 parent 14b63e4 commit 35a7813
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 39 deletions.
27 changes: 23 additions & 4 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ def test_jumanji_seeding(self, envname):

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_jumanji_batch_size(self, envname, batch_size):
env = JumanjiEnv(envname, batch_size=batch_size)
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
env.set_seed(0)
tdreset = env.reset()
tdrollout = env.rollout(max_steps=50)
Expand All @@ -1616,7 +1616,7 @@ def test_jumanji_batch_size(self, envname, batch_size):

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_jumanji_spec_rollout(self, envname, batch_size):
env = JumanjiEnv(envname, batch_size=batch_size)
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
env.set_seed(0)
check_env_specs(env)

Expand All @@ -1627,7 +1627,7 @@ def test_jumanji_consistency(self, envname, batch_size):
import numpy as onp
from torchrl.envs.libs.jax_utils import _tree_flatten

env = JumanjiEnv(envname, batch_size=batch_size)
env = JumanjiEnv(envname, batch_size=batch_size, jit=True)
obs_keys = list(env.observation_spec.keys(True))
env.set_seed(1)
rollout = env.rollout(10)
Expand Down Expand Up @@ -1665,19 +1665,38 @@ def test_jumanji_consistency(self, envname, batch_size):
@pytest.mark.parametrize("batch_size", [[3], []])
def test_jumanji_rendering(self, envname, batch_size):
# check that this works with a batch-size
env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size)
env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size, jit=True)
env.set_seed(0)
env.transform.transform_observation_spec(env.base_env.observation_spec)

r = env.rollout(10)
pixels = r["pixels"]
if not isinstance(pixels, torch.Tensor):
pixels = torch.as_tensor(np.asarray(pixels))
assert batch_size
else:
assert not batch_size
assert pixels.unique().numel() > 1
assert pixels.dtype == torch.uint8

check_env_specs(env)

@pytest.mark.parametrize("jit", [True, False])
def test_jumanji_batch_unlocked(self, envname, jit):
torch.manual_seed(0)
env = JumanjiEnv(envname, jit=jit)
env.set_seed(0)
assert not env.batch_locked
reset = env.reset(TensorDict(batch_size=[16]))
assert reset.batch_size == (16,)
env.rand_step(reset)
r = env.rollout(
2000, auto_reset=False, tensordict=reset, break_when_all_done=True
)
assert r.batch_size[0] == 16
done = r["next", "done"]
assert done.any(-2).all() or (r.shape[-1] == 2000)


ENVPOOL_CLASSIC_CONTROL_ENVS = [
PENDULUM_VERSIONED(),
Expand Down
27 changes: 23 additions & 4 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,21 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
# sanity check
self._assert_tensordict_shape(tensordict)
partial_steps = None

if not self.batch_locked:
# Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here
partial_steps = tensordict.get("_step", None)
if partial_steps is not None:
if partial_steps.all():
partial_steps = None
else:
tensordict_batch_size = tensordict.batch_size
partial_steps = partial_steps.view(tensordict_batch_size)
tensordict = tensordict[partial_steps]
else:
tensordict_batch_size = self.batch_size

next_preset = tensordict.get("next", None)

next_tensordict = self._step(tensordict)
Expand All @@ -1512,6 +1527,10 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
next_preset.exclude(*next_tensordict.keys(True, True))
)
tensordict.set("next", next_tensordict)
if partial_steps is not None:
result = tensordict.new_zeros(tensordict_batch_size)
result[partial_steps] = tensordict
return result
return tensordict

@classmethod
Expand Down Expand Up @@ -2731,7 +2750,7 @@ def _rollout_stop_early(
if break_when_all_done:
if partial_steps is not True:
# At least one partial step has been done
del td_append["_partial_steps"]
del td_append["_step"]
td_append = torch.where(
partial_steps.view(td_append.shape), td_append, tensordicts[-1]
)
Expand All @@ -2757,17 +2776,17 @@ def _rollout_stop_early(
_terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key="_partial_steps",
key="_step",
write_full_false=False,
)
partial_step_curr = tensordict.get("_partial_steps", None)
partial_step_curr = tensordict.get("_step", None)
if partial_step_curr is not None:
partial_step_curr = ~partial_step_curr
partial_steps = partial_steps & partial_step_curr
if partial_steps is not True:
if not partial_steps.any():
break
tensordict.set("_partial_steps", partial_steps)
tensordict.set("_step", partial_steps)

if callback is not None:
callback(self, tensordict)
Expand Down
12 changes: 8 additions & 4 deletions torchrl/envs/libs/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,21 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase:
return None


def _tensordict_to_object(tensordict: TensorDictBase, object_example):
def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size=None):
"""Converts a TensorDict to a namedtuple or a dataclass."""
from jax import dlpack as jax_dlpack, numpy as jnp

if batch_size is None:
batch_size = []
t = {}
_fields = _get_object_fields(object_example)
for name, example in _fields.items():
value = tensordict.get(name, None)
if isinstance(value, TensorDictBase):
t[name] = _tensordict_to_object(value, example)
t[name] = _tensordict_to_object(value, example, batch_size=batch_size)
elif value is None:
if isinstance(example, dict):
t[name] = _tensordict_to_object({}, example)
t[name] = _tensordict_to_object({}, example, batch_size=batch_size)
else:
t[name] = None
else:
Expand All @@ -140,7 +142,9 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example):
t[name] = value
else:
value = jnp.reshape(value, tuple(shape))
t[name] = value.view(example.dtype).reshape(example.shape)
t[name] = value.view(example.dtype).reshape(
(*batch_size, *example.shape)
)
return type(object_example)(**t)


Expand Down
Loading

0 comments on commit 35a7813

Please sign in to comment.