From cb188a8c81f15992734d672a000fdba3616033f9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 09:36:23 -0400 Subject: [PATCH] [Feature] flexible batch_locked for jumanji ghstack-source-id: b62e657f5d46af7b511363b4139379bde626d4e0 Pull Request resolved: https://github.com/pytorch/rl/pull/2382 --- test/test_libs.py | 25 +++++- torchrl/envs/common.py | 22 +++++- torchrl/envs/libs/jax_utils.py | 12 ++- torchrl/envs/libs/jumanji.py | 138 +++++++++++++++++++++++++++------ 4 files changed, 160 insertions(+), 37 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index a76cb610d69..4a333bfddcb 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1553,7 +1553,7 @@ def test_jumanji_seeding(self, envname): @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) def test_jumanji_batch_size(self, envname, batch_size): - env = JumanjiEnv(envname, batch_size=batch_size) + env = JumanjiEnv(envname, batch_size=batch_size, jit=True) env.set_seed(0) tdreset = env.reset() tdrollout = env.rollout(max_steps=50) @@ -1564,7 +1564,7 @@ def test_jumanji_batch_size(self, envname, batch_size): @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) def test_jumanji_spec_rollout(self, envname, batch_size): - env = JumanjiEnv(envname, batch_size=batch_size) + env = JumanjiEnv(envname, batch_size=batch_size, jit=True) env.set_seed(0) check_env_specs(env) @@ -1575,7 +1575,7 @@ def test_jumanji_consistency(self, envname, batch_size): import numpy as onp from torchrl.envs.libs.jax_utils import _tree_flatten - env = JumanjiEnv(envname, batch_size=batch_size) + env = JumanjiEnv(envname, batch_size=batch_size, jit=True) obs_keys = list(env.observation_spec.keys(True)) env.set_seed(1) rollout = env.rollout(10) @@ -1613,7 +1613,7 @@ def test_jumanji_consistency(self, envname, batch_size): @pytest.mark.parametrize("batch_size", [[3], []]) def test_jumanji_rendering(self, envname, batch_size): # check that this works with a batch-size - env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size) + env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size, jit=True) env.set_seed(0) env.transform.transform_observation_spec(env.base_env.observation_spec) @@ -1626,6 +1626,23 @@ def test_jumanji_rendering(self, envname, batch_size): check_env_specs(env) + @pytest.mark.parametrize("jit", [True, False]) + def test_jumanji_batch_unlocked(self, envname, jit): + torch.manual_seed(0) + env = JumanjiEnv(envname, jit=jit) + env.set_seed(0) + assert not env.batch_locked + reset = env.reset(TensorDict(batch_size=[16])) + assert reset.batch_size == (16,) + env.rand_step(reset) + t0 = time.time() + r = env.rollout( + 20, auto_reset=False, tensordict=reset, break_when_all_done=True + ) + assert r.batch_size[0] == 16 + done = r["next", "done"].float() + assert (done.cumprod(-2) == done).all() + ENVPOOL_CLASSIC_CONTROL_ENVS = [ PENDULUM_VERSIONED(), diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 2aacf76168b..10ef15737a6 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1473,6 +1473,16 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: """ # sanity check self._assert_tensordict_shape(tensordict) + if not self.batch_locked: + # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here + partial_steps = tensordict.get("_step", None) + if partial_steps is not None: + if partial_steps.all(): + partial_steps = None + else: + tensordict_batch_size = tensordict.batch_size + partial_steps = partial_steps.view(tensordict_batch_size) + tensordict = tensordict[partial_steps] next_preset = tensordict.get("next", None) next_tensordict = self._step(tensordict) @@ -1485,6 +1495,10 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: next_preset.exclude(*next_tensordict.keys(True, True)) ) tensordict.set("next", next_tensordict) + if partial_steps is not None: + result = tensordict.new_zeros(tensordict_batch_size) + result[partial_steps] = tensordict + return result return tensordict @classmethod @@ -2696,7 +2710,7 @@ def _rollout_stop_early( if break_when_all_done: if partial_steps is not True: # At least one partial step has been done - del td_append["_partial_steps"] + del td_append["_step"] td_append = torch.where( partial_steps.view(td_append.shape), td_append, tensordicts[-1] ) @@ -2722,17 +2736,17 @@ def _rollout_stop_early( _terminated_or_truncated( tensordict, full_done_spec=self.output_spec["full_done_spec"], - key="_partial_steps", + key="_step", write_full_false=False, ) - partial_step_curr = tensordict.get("_partial_steps", None) + partial_step_curr = tensordict.get("_step", None) if partial_step_curr is not None: partial_step_curr = ~partial_step_curr partial_steps = partial_steps & partial_step_curr if partial_steps is not True: if not partial_steps.any(): break - tensordict.set("_partial_steps", partial_steps) + tensordict.set("_step", partial_steps) if callback is not None: callback(self, tensordict) diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index 052f538f0c4..086533cb487 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -102,19 +102,21 @@ def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase: return None -def _tensordict_to_object(tensordict: TensorDictBase, object_example): +def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size=None): """Converts a TensorDict to a namedtuple or a dataclass.""" from jax import dlpack as jax_dlpack, numpy as jnp + if batch_size is None: + batch_size = [] t = {} _fields = _get_object_fields(object_example) for name, example in _fields.items(): value = tensordict.get(name, None) if isinstance(value, TensorDictBase): - t[name] = _tensordict_to_object(value, example) + t[name] = _tensordict_to_object(value, example, batch_size=batch_size) elif value is None: if isinstance(example, dict): - t[name] = _tensordict_to_object({}, example) + t[name] = _tensordict_to_object({}, example, batch_size=batch_size) else: t[name] = None else: @@ -140,7 +142,9 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example): t[name] = value else: value = jnp.reshape(value, tuple(shape)) - t[name] = value.view(example.dtype).reshape(example.shape) + t[name] = value.view(example.dtype).reshape( + (*batch_size, *example.shape) + ) return type(object_example)(**t) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index dbbc980e8cc..5278683acba 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -128,6 +128,12 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender): Paper: https://arxiv.org/abs/2306.09884 + .. note:: For better performance, turn `jit` on when instantiating this class. + The `jit` attribute can also be flipped during code execution: + + >>> env.jit = True # Used jit + >>> env.jit = False # eager + Args: env (jumanji.env.Environment): the env to wrap. categorical_action_encoding (bool, optional): if ``True``, categorical @@ -136,6 +142,22 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender): Defaults to ``False``. Keyword Args: + batch_size (torch.Size, optional): the batch size of the environment. + With ``jumanji``, this indicates the number of vectorized environments. + If the batch-size is empty, the environment is not batch-locked and an arbitrary number + of environments can be executed simultaneously. + Defaults to ``torch.Size([])``. + + >>> import jumanji + >>> from torchrl.envs import JumanjiWrapper + >>> base_env = jumanji.make("Snake-v1") + >>> env = JumanjiWrapper(base_env) + >>> # Set the batch-size of the TensorDict instead of the env allows to control the number + >>> # of envs being run simultaneously + >>> tdreset = env.reset(TensorDict(batch_size=[32])) + >>> # Execute a rollout until all envs are done or max steps is reached, whichever comes first + >>> rollout = env.rollout(100, break_when_all_done=True, auto_reset=False, tensordict=tdreset) + from_pixels (bool, optional): Whether the environment should render its output. This will drastically impact the environment throughput. Only the first environment will be rendered. See :meth:`~torchrl.envs.JumanjiWrapper.render` for more information. @@ -146,17 +168,15 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender): of rewards across steps. device (torch.device, optional): if provided, the device on which the data is to be cast. Defaults to ``torch.device("cpu")``. - batch_size (torch.Size, optional): the batch size of the environment. - With ``jumanji``, this indicates the number of vectorized environments. - Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`~.reset` is called. Defaults to ``False``. + jit (bool, optional): whether the step and reset method should be wrapped in `jit`. + Defaults to ``False``. Attributes: available_envs: environments availalbe to build - Examples: Examples: >>> import jumanji >>> from torchrl.envs import JumanjiWrapper @@ -334,6 +354,7 @@ def __init__( self, env: "jumanji.env.Environment" = None, # noqa: F821 categorical_action_encoding=True, + jit: bool = True, **kwargs, ): if not _has_jumanji: @@ -343,7 +364,26 @@ def __init__( self.categorical_action_encoding = categorical_action_encoding if env is not None: kwargs["env"] = env + batch_locked = kwargs.pop("batch_locked", kwargs.get("batch_size") is not None) super().__init__(**kwargs) + self._batch_locked = batch_locked + self.jit = jit + + @property + def jit(self): + return self._jit + + @jit.setter + def jit(self, value): + self._jit = value + if value: + import jax + + self._env_reset = jax.jit(self._env.reset) + self._env_step = jax.jit(self._env.step) + else: + self._env_reset = self._env.reset + self._env_step = self._env.step def _build_env( self, @@ -486,17 +526,21 @@ def _set_seed(self, seed): raise Exception("Jumanji requires an integer seed.") self.key = jax.random.PRNGKey(seed) - def read_state(self, state): - state_dict = _object_to_tensordict(state, self.device, self.batch_size) + def read_state(self, state, batch_size=None): + state_dict = _object_to_tensordict( + state, self.device, self.batch_size if batch_size is None else batch_size + ) return self.state_spec["state"].encode(state_dict) - def read_obs(self, obs): + def read_obs(self, obs, batch_size=None): from jax import numpy as jnp if isinstance(obs, (list, jnp.ndarray, np.ndarray)): obs_dict = _ndarray_to_tensor(obs).to(self.device) else: - obs_dict = _object_to_tensordict(obs, self.device, self.batch_size) + obs_dict = _object_to_tensordict( + obs, self.device, self.batch_size if batch_size is None else batch_size + ) return super().read_obs(obs_dict) def render( @@ -561,7 +605,11 @@ def render( isinteractive = plt.isinteractive() plt.ion() buf = io.BytesIO() - state = _tensordict_to_object(tensordict.get("state"), _state_example) + state = _tensordict_to_object( + tensordict.get("state"), + _state_example, + batch_size=tensordict.batch_size if not self.batch_locked else None, + ) self._env.render(state, **kwargs) plt.savefig(buf, format="png") buf.seek(0) @@ -580,24 +628,33 @@ def render( def _step(self, tensordict: TensorDictBase) -> TensorDictBase: import jax + if self.batch_locked: + batch_size = self.batch_size + else: + batch_size = tensordict.batch_size + # prepare inputs - state = _tensordict_to_object(tensordict.get("state"), self._state_example) + state = _tensordict_to_object( + tensordict.get("state"), + self._state_example, + batch_size=tensordict.batch_size if not self.batch_locked else None, + ) action = self.read_action(tensordict.get("action")) # flatten batch size into vector - state = _tree_flatten(state, self.batch_size) - action = _tree_flatten(action, self.batch_size) + state = _tree_flatten(state, batch_size) + action = _tree_flatten(action, batch_size) # jax vectorizing map on env.step - state, timestep = jax.vmap(self._env.step)(state, action) + state, timestep = jax.vmap(self._env_step)(state, action) # reshape batch size from vector - state = _tree_reshape(state, self.batch_size) - timestep = _tree_reshape(timestep, self.batch_size) + state = _tree_reshape(state, batch_size) + timestep = _tree_reshape(timestep, batch_size) # collect outputs - state_dict = self.read_state(state) - obs_dict = self.read_obs(timestep.observation) + state_dict = self.read_state(state, batch_size=batch_size) + obs_dict = self.read_obs(timestep.observation, batch_size=batch_size) reward = self.read_reward(np.asarray(timestep.reward)) done = timestep.step_type == self.lib.types.StepType.LAST done = _ndarray_to_tensor(done).view(torch.bool).to(self.device) @@ -622,25 +679,35 @@ def _reset( import jax from jax import numpy as jnp + if self.batch_locked: + numel = self.numel() + batch_size = self.batch_size + else: + numel = tensordict.numel() + batch_size = tensordict.batch_size + # generate random keys - self.key, *keys = jax.random.split(self.key, self.numel() + 1) + self.key, *keys = jax.random.split(self.key, numel + 1) # jax vectorizing map on env.reset - state, timestep = jax.vmap(self._env.reset)(jnp.stack(keys)) + state, timestep = jax.vmap(self._env_reset)(jnp.stack(keys)) # reshape batch size from vector - state = _tree_reshape(state, self.batch_size) - timestep = _tree_reshape(timestep, self.batch_size) + state = _tree_reshape(state, batch_size) + timestep = _tree_reshape(timestep, batch_size) # collect outputs - state_dict = self.read_state(state) - obs_dict = self.read_obs(timestep.observation) - done_td = self.full_done_spec.zero() + state_dict = self.read_state(state, batch_size=batch_size) + obs_dict = self.read_obs(timestep.observation, batch_size=batch_size) + if not self.batch_locked: + done_td = self.full_done_spec.zero(batch_size) + else: + done_td = self.full_done_spec.zero() # build results tensordict_out = TensorDict( source=obs_dict, - batch_size=self.batch_size, + batch_size=batch_size, device=self.device, ) tensordict_out.update(done_td) @@ -648,6 +715,27 @@ def _reset( return tensordict_out + def read_reward(self, reward): + """Reads the reward and maps it to the reward space. + + Args: + reward (torch.Tensor or TensorDict): reward to be mapped. + + """ + if isinstance(reward, int) and reward == 0: + return self.reward_spec.zero() + if self.batch_locked: + reward = self.reward_spec.encode(reward, ignore_device=True) + else: + reward = torch.as_tensor(reward) + if reward.shape[-1] != self.reward_spec.shape[-1]: + reward = reward.unsqueeze(-1) + + if reward is None: + reward = torch.tensor(np.nan).expand(self.reward_spec.shape) + + return reward + def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: ...