diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index 29363d3570..47d5cb1bd6 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -366,7 +366,6 @@ class DREAMERVAC(nn.Module):
     def __init__(
             self,
-            obs_shape: Union[int, SequenceType],
             action_shape: Union[int, SequenceType, EasyDict],
             dyn_stoch=32,
             dyn_deter=512,
@@ -391,9 +390,8 @@ def __init__(
             - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
         """
         super(DREAMERVAC, self).__init__()
-        obs_shape: int = squeeze(obs_shape)
         action_shape = squeeze(action_shape)
-        self.obs_shape, self.action_shape = obs_shape, action_shape
+        self.action_shape = action_shape
         if dyn_discrete:
             feat_size = dyn_stoch * dyn_discrete + dyn_deter
diff --git a/ding/policy/mbpolicy/dreamer.py b/ding/policy/mbpolicy/dreamer.py
index 43d3b88619..35287c1bb6 100644
--- a/ding/policy/mbpolicy/dreamer.py
+++ b/ding/policy/mbpolicy/dreamer.py
@@ -234,8 +234,11 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N
                         latent[key][i] *= mask[i]
                 for i in range(len(action)):
                     action[i] *= mask[i]
-
-        data = data - 0.5
+        assert world_model.obs_type == 'vector' or world_model.obs_type == 'RGB', \
+            "obs type must be vector or RGB"
+        # normalize RGB image input
+        if world_model.obs_type == 'RGB':
+            data = data - 0.5
         embed = world_model.encoder(data)
         latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
         feat = world_model.dynamics.get_feat(latent)
@@ -247,11 +250,18 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N
         action = action.detach()
         state = (latent, action)
+        assert world_model.action_type == 'discrete' or world_model.action_type == 'continuous', \
+            "action type must be continuous or discrete"
+        if world_model.action_type == 'discrete':
+            action = torch.where(action == 1)[1]
         output = {"action": action, "logprob": logprob, "state": state}
         if self._cuda:
             output = to_device(output, 'cpu')
         output = default_decollate(output)
+        if world_model.action_type == 'discrete':
+            for l in range(len(output)):
+                output[l]['action'] = output[l]['action'].squeeze(0)
         return {i: d for i, d in zip(data_id, output)}
     def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
@@ -272,7 +282,7 @@ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple
             # TODO(zp) random_collect just have action
             #'logprob': model_output['logprob'],
             'reward': timestep.reward,
-            'discount': timestep.info['discount'],
+            'discount': 1. 
- timestep.done, # timestep.info['discount'], 'done': timestep.done, } return transition @@ -309,7 +319,9 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict for i in range(len(action)): action[i] *= mask[i] - data = data - 0.5 + # normalize RGB image input + if world_model.obs_type == 'RGB': + data = data - 0.5 embed = world_model.encoder(data) latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample) feat = world_model.dynamics.get_feat(latent) @@ -321,11 +333,16 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict action = action.detach() state = (latent, action) + if world_model.action_type == 'discrete': + action = torch.where(action == 1)[1] output = {"action": action, "logprob": logprob, "state": state} if self._cuda: output = to_device(output, 'cpu') output = default_decollate(output) + if world_model.action_type == 'discrete': + for l in range(len(output)): + output[l]['action'] = output[l]['action'].squeeze(0) return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: diff --git a/ding/torch_utils/network/dreamer.py b/ding/torch_utils/network/dreamer.py index f7c1597e54..b7ae67b57c 100644 --- a/ding/torch_utils/network/dreamer.py +++ b/ding/torch_utils/network/dreamer.py @@ -178,7 +178,7 @@ def forward(self, features): elif self._dist == "binary": return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape))) elif self._dist == "twohot_symlog": - return TwoHotDistSymlog(logits=mean, device=self._device) + return TwoHotDistSymlog(logits=mean, low=-1., high=1., device=self._device) raise NotImplementedError(self._dist) @@ -475,8 +475,8 @@ def log_prob(self, x): above = torch.clip(above, 0, len(self.buckets) - 1) equal = (below == above) - dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x)) - dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x)) + dist_to_below = torch.where(equal, torch.tensor(1).to(x), torch.abs(self.buckets[below] - x)) + dist_to_above = torch.where(equal, torch.tensor(1).to(x), torch.abs(self.buckets[above] - x)) total = dist_to_below + dist_to_above weight_below = dist_to_above / total weight_above = dist_to_below / total diff --git a/ding/world_model/dreamer.py b/ding/world_model/dreamer.py index eafe257454..ceac6fa082 100644 --- a/ding/world_model/dreamer.py +++ b/ding/world_model/dreamer.py @@ -5,10 +5,10 @@ from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts from ding.utils.data import default_collate -from ding.model import ConvEncoder +from ding.model import ConvEncoder, FCEncoder from ding.world_model.base_world_model import WorldModel from ding.world_model.model.networks import RSSM, ConvDecoder -from ding.torch_utils import to_device +from ding.torch_utils import to_device, one_hot from ding.torch_utils.network.dreamer import DenseHead @@ -37,6 +37,7 @@ class DREAMERWorldModel(WorldModel, nn.Module): norm='LayerNorm', grad_heads=['image', 'reward', 'discount'], units=512, + image_dec_layers=2, reward_layers=2, discount_layers=2, value_layers=2, @@ -71,26 +72,33 @@ def __init__(self, cfg, env, tb_logger): self._cfg.act = nn.modules.activation.SiLU # nn.SiLU self._cfg.norm = nn.modules.normalization.LayerNorm # nn.LayerNorm self.state_size = self._cfg.state_size + self.obs_type = self._cfg.obs_type self.action_size = self._cfg.action_size + self.action_type = self._cfg.action_type self.reward_size = self._cfg.reward_size 
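Note: the constructor below now switches the observation encoder on `obs_type`. A minimal sketch of that selection logic, assuming only the cfg fields used in this PR (the `build_encoder` helper is illustrative and not part of the patch):

    import torch.nn as nn
    from ding.model import ConvEncoder, FCEncoder

    def build_encoder(cfg):
        if cfg.obs_type == 'vector':
            # MLP encoder for flat observations; the embedding size is simply the last hidden size
            encoder = FCEncoder(cfg.state_size, cfg.encoder_hidden_size_list, activation=nn.SiLU())
            embed_size = cfg.encoder_hidden_size_list[-1]
        elif cfg.obs_type == 'RGB':
            # CNN encoder for image observations; the flattened embedding size follows from the
            # image resolution, the number of conv layers and cnn_depth, as in the hunk below
            encoder = ConvEncoder(
                cfg.state_size,
                hidden_size_list=[32, 64, 128, 256, 4096],
                activation=nn.SiLU(),
                kernel_size=cfg.encoder_kernels,
                layer_norm=True,
            )
            embed_size = (cfg.state_size[1] // 2 ** len(cfg.encoder_kernels)) ** 2 \
                * cfg.cnn_depth * 2 ** (len(cfg.encoder_kernels) - 1)
        else:
            raise KeyError("obs_type must be 'vector' or 'RGB', got {}".format(cfg.obs_type))
        return encoder, embed_size

For the MiniGrid config added in this PR (encoder_hidden_size_list=[256, 128, 64, 64]) the vector branch yields an embedding of size 64, while the dmc2gym image configs keep the original ConvEncoder path.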
self.hidden_size = self._cfg.hidden_size self.batch_size = self._cfg.batch_size + if self.obs_type == 'vector': + self.encoder = FCEncoder(self.state_size, self._cfg.encoder_hidden_size_list, activation=torch.nn.SiLU()) + self.embed_size = self._cfg.encoder_hidden_size_list[-1] + elif self.obs_type == 'RGB': + self.encoder = ConvEncoder( + self.state_size, + hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128? + activation=torch.nn.SiLU(), + kernel_size=self._cfg.encoder_kernels, + layer_norm=True + ) + self.embed_size = ( + (self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth * + 2 ** (len(self._cfg.encoder_kernels) - 1) + ) - self.encoder = ConvEncoder( - self.state_size, - hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128? - activation=torch.nn.SiLU(), - kernel_size=self._cfg.encoder_kernels, - layer_norm=True - ) - self.embed_size = ( - (self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth * - 2 ** (len(self._cfg.encoder_kernels) - 1) - ) self.dynamics = RSSM( self._cfg.dyn_stoch, self._cfg.dyn_deter, self._cfg.dyn_hidden, + self._cfg.action_type, self._cfg.dyn_input_layers, self._cfg.dyn_output_layers, self._cfg.dyn_rec_depth, @@ -113,14 +121,28 @@ def __init__(self, cfg, env, tb_logger): feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter else: feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter - self.heads["image"] = ConvDecoder( - feat_size, # pytorch version - self._cfg.cnn_depth, - self._cfg.act, - self._cfg.norm, - self.state_size, - self._cfg.decoder_kernels, - ) + + if isinstance(self.state_size, int): + self.heads['image'] = DenseHead( + feat_size, + (self.state_size, ), + self._cfg.image_dec_layers, + self._cfg.units, + 'SiLU', # self._cfg.act + 'LN', # self._cfg.norm + dist='binary', + outscale=0.0, + device=self._cfg.device, + ) + elif len(self.state_size) == 3: + self.heads["image"] = ConvDecoder( + feat_size, # pytorch version + self._cfg.cnn_depth, + self._cfg.act, + self._cfg.norm, + self.state_size, + self._cfg.decoder_kernels, + ) self.heads["reward"] = DenseHead( feat_size, # dyn_stoch * dyn_discrete + dyn_deter (255, ), @@ -172,9 +194,15 @@ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length): data = {k: torch.stack(data[k], dim=1) for k in data} # -> {dict_key: Tensor([B, T, any_dims])} data['discount'] = data.get('discount', 1.0 - data['done'].float()) - data['discount'] *= 0.997 data['weight'] = data.get('weight', None) - data['image'] = data['obs'] - 0.5 + if self.obs_type == 'RGB': + data['image'] = data['obs'] - 0.5 + else: + data['image'] = data['obs'] + if self.action_type == 'continuous': + data['action'] *= (1.0 / torch.clip(torch.abs(data['action']), min=1.0)) + else: + data['action'] = one_hot(data['action'], self.action_size) data = to_device(data, self._cfg.device) if len(data['reward'].shape) == 2: data['reward'] = data['reward'].unsqueeze(-1) @@ -185,9 +213,9 @@ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length): self.requires_grad_(requires_grad=True) - image = data['image'].reshape([-1] + list(data['image'].shape[-3:])) + image = data['image'].reshape([-1] + list(data['image'].shape[2:])) embed = self.encoder(image) - embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]]) + embed = embed.reshape(list(data['image'].shape[:2]) + [embed.shape[-1]]) post, prior = self.dynamics.observe(embed, data["action"]) kl_loss, kl_value, loss_lhs, loss_rhs = 
self.dynamics.kl_loss( diff --git a/ding/world_model/model/networks.py b/ding/world_model/model/networks.py index 091fa4f827..2e47cc6604 100644 --- a/ding/world_model/model/networks.py +++ b/ding/world_model/model/networks.py @@ -1,11 +1,12 @@ import math import numpy as np +from typing import Optional, Dict, Union, List import torch from torch import nn import torch.nn.functional as F from torch import distributions as torchd - +from ding.utils import SequenceType from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, static_scan, \ OneHotDist, ContDist, SymlogDist, DreamerLayerNorm @@ -17,6 +18,7 @@ def __init__( stoch=30, deter=200, hidden=200, + action_type=None, layers_input=1, layers_output=1, rec_depth=1, @@ -38,6 +40,7 @@ def __init__( self._stoch = stoch self._deter = deter self._hidden = hidden + self._action_type = action_type self._min_std = min_std self._layers_input = layers_input self._layers_output = layers_output @@ -179,7 +182,8 @@ def get_dist(self, state, dtype=None): def obs_step(self, prev_state, prev_action, embed, sample=True): # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer) # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs - prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() + if self._action_type == 'continuous': + prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() prior = self.img_step(prev_state, prev_action, None, sample) if self._shared: post = self.img_step(prev_state, prev_action, embed, sample) @@ -202,7 +206,8 @@ def obs_step(self, prev_state, prev_action, embed, sample=True): # this is used for making future image def img_step(self, prev_state, prev_action, embed=None, sample=True): # (batch, stoch, discrete_num) - prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() + if self._action_type == 'continuous': + prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() prev_stoch = prev_state["stoch"] if self._discrete: shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete] @@ -282,8 +287,9 @@ def kl_loss(self, post, prior, forward, free, lscale, rscale): dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist, dist(rhs) if self._discrete else dist(rhs)._dist, ) - loss_lhs = torch.clip(torch.mean(value_lhs), min=free) - loss_rhs = torch.clip(torch.mean(value_rhs), min=free) + # free bits + loss_lhs = torch.mean(torch.clip(value_lhs, min=free)) + loss_rhs = torch.mean(torch.clip(value_rhs, min=free)) loss = lscale * loss_lhs + rscale * loss_rhs return loss, value, loss_lhs, loss_rhs @@ -357,7 +363,7 @@ def calc_same_pad(self, k, s, d): outpad = pad * 2 - val return pad, outpad - def __call__(self, features, dtype=None): + def __call__(self, features): x = self._linear_layer(features) # feature:[batch, time, stoch*discrete + deter] x = x.reshape([-1, 4, 4, self._embed_size // 16]) x = x.permute(0, 3, 1, 2) diff --git a/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py b/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py index 66f7c7e2a4..623cfaacf1 100644 --- a/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py +++ b/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py @@ -60,7 +60,9 @@ cuda=cuda, model=dict( state_size=(3, 64, 64), # has to be specified + obs_type='RGB', action_size=1, # has to be specified + action_type='continuous', 
reward_size=1, batch_size=16, ), diff --git a/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py b/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py index 32a43463e7..22b6ae911b 100644 --- a/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py +++ b/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py @@ -60,7 +60,9 @@ cuda=cuda, model=dict( state_size=(3, 64, 64), # has to be specified + obs_type='RGB', action_size=6, # has to be specified + action_type='continuous', reward_size=1, batch_size=16, ), diff --git a/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py b/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py index 16e76eac39..da7f9e0edb 100644 --- a/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py +++ b/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py @@ -28,7 +28,6 @@ # it is better to put random_collect_size in policy.other random_collect_size=2500, model=dict( - obs_shape=(3, 64, 64), action_shape=6, actor_dist='normal', ), @@ -60,7 +59,9 @@ cuda=cuda, model=dict( state_size=(3, 64, 64), # has to be specified + obs_type='RGB', action_size=6, # has to be specified + action_type='continuous', reward_size=1, batch_size=16, ), diff --git a/dizoo/minigrid/config/minigrid_dreamer_config.py b/dizoo/minigrid/config/minigrid_dreamer_config.py new file mode 100644 index 0000000000..410f803d96 --- /dev/null +++ b/dizoo/minigrid/config/minigrid_dreamer_config.py @@ -0,0 +1,96 @@ +from easydict import EasyDict + +from ding.entry import serial_pipeline_dreamer + +cuda = False +collector_env_num = 8 +evaluator_env_num = 5 +minigrid_dreamer_config = dict( + exp_name='minigrid_dreamer_empty', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + # typical MiniGrid env id: + # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, + # please refer to https://github.com/Farama-Foundation/MiniGrid for details. 
+ env_id='MiniGrid-Empty-8x8-v0', + # env_id='MiniGrid-AKTDT-7x7-1-v0', + max_step=100, + stop_value=20, # run fixed env_steps + # stop_value=0.96, + flat_obs=True, + full_obs=True, + onehot_obs=True, + move_bonus=True, + ), + policy=dict( + cuda=cuda, + # it is better to put random_collect_size in policy.other + random_collect_size=2500, + model=dict( + action_shape=7, + # encoder_hidden_size_list=[256, 128, 64, 64], + # critic_head_hidden_size=64, + # actor_head_hidden_size=64, + actor_dist='onehot', + ), + learn=dict( + lambda_=0.95, + learning_rate=3e-5, + batch_size=16, + batch_length=64, + imag_sample=True, + discount=0.997, + reward_EMA=True, + ), + collect=dict( + n_sample=1, + unroll_len=1, + action_size=7, # has to be specified + collect_dyn_sample=True, + ), + eval=dict(evaluator=dict(eval_freq=5000, )), + other=dict( + # environment buffer + replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60), + ), + ), + world_model=dict( + pretrain=100, + train_freq=2, + cuda=cuda, + model=dict( + state_size=1344, + obs_type = 'vector', + action_size=7, + action_type='discrete', + encoder_hidden_size_list=[256, 128, 64, 64], + reward_size=1, + batch_size=16, + ), + ), +) + +minigrid_dreamer_config = EasyDict(minigrid_dreamer_config) + +minigrid_create_config = dict( + env=dict( + type='minigrid', + import_names=['dizoo.minigrid.envs.minigrid_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='dreamer', + import_names=['ding.policy.mbpolicy.dreamer'], + ), + replay_buffer=dict(type='sequence', ), + world_model=dict( + type='dreamer', + import_names=['ding.world_model.dreamer'], + ), +) +minigrid_create_config = EasyDict(minigrid_create_config) + +if __name__ == '__main__': + serial_pipeline_dreamer((minigrid_dreamer_config, minigrid_create_config), seed=0, max_env_step=500000) diff --git a/dizoo/minigrid/envs/minigrid_env.py b/dizoo/minigrid/envs/minigrid_env.py index 12bd64cae0..923967394e 100644 --- a/dizoo/minigrid/envs/minigrid_env.py +++ b/dizoo/minigrid/envs/minigrid_env.py @@ -9,8 +9,8 @@ import numpy as np from matplotlib import animation import matplotlib.pyplot as plt -from minigrid.wrappers import FlatObsWrapper, RGBImgPartialObsWrapper, ImgObsWrapper -from .minigrid_wrapper import ViewSizeWrapper +from minigrid.wrappers import FullyObsWrapper +from .minigrid_wrapper import ViewSizeWrapper, MoveBonus, OneHotObsWrapper, FlatObsWrapper from ding.envs import ObsPlusPrevActRewWrapper from ding.envs import BaseEnv, BaseEnvTimestep @@ -36,6 +36,9 @@ def __init__(self, cfg: dict) -> None: self._init_flag = False self._env_id = cfg.env_id self._flat_obs = cfg.flat_obs + self._full_obs = cfg.full_obs + self._onehot_obs = cfg.onehot_obs + self._move_bonus = cfg.move_bonus self._save_replay = False self._max_step = cfg.max_step @@ -52,6 +55,12 @@ def reset(self) -> np.ndarray: self._env = ViewSizeWrapper(self._env, agent_view_size=5) if self._env_id == 'MiniGrid-AKTDT-7x7-1-v0': self._env = ViewSizeWrapper(self._env, agent_view_size=3) + if self._full_obs: + self._env = FullyObsWrapper(self._env) + if self._onehot_obs: + self._env = OneHotObsWrapper(self._env) + if self._move_bonus: + self._env = MoveBonus(self._env) if self._flat_obs: self._env = FlatObsWrapper(self._env) # self._env = RGBImgPartialObsWrapper(self._env) @@ -60,7 +69,7 @@ def reset(self) -> np.ndarray: self._env = ObsPlusPrevActRewWrapper(self._env) self._init_flag = True if self._flat_obs: - self._observation_space = gym.spaces.Box(0, 1, shape=(2835, ), dtype=np.float32) + 
self._observation_space = gym.spaces.Box(0, 1, shape=self._env.observation_space.shape, dtype=np.float32)
         else:
             self._observation_space = self._env.observation_space
             # to be compatiable with subprocess env manager
diff --git a/dizoo/minigrid/envs/minigrid_wrapper.py b/dizoo/minigrid/envs/minigrid_wrapper.py
index 09a14c9c81..72683d159d 100644
--- a/dizoo/minigrid/envs/minigrid_wrapper.py
+++ b/dizoo/minigrid/envs/minigrid_wrapper.py
@@ -1,6 +1,10 @@
 import gymnasium as gym
 from gymnasium import spaces
-from gymnasium.core import ObservationWrapper
+from gymnasium.core import ObservationWrapper, Wrapper
+import numpy as np
+import operator
+from functools import reduce
+from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
 class ViewSizeWrapper(ObservationWrapper):
@@ -32,3 +36,161 @@ def observation(self, obs):
         # print('vis_mask:' + vis_mask)
         image = grid.encode(vis_mask)
         return {**obs, "image": image}
+
+
+class MoveBonus(Wrapper):
+    """
+    Adds a movement bonus based on how much each step
+    reduces the agent's distance to the goal position.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> from dizoo.minigrid.envs.minigrid_wrapper import MoveBonus
+        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
+        >>> env_bonus = MoveBonus(env)
+        >>> obs, _ = env_bonus.reset(seed=0)
+        >>> # stepping forward towards the goal adds a positive bonus,
+        >>> # moving away subtracts it, turning in place changes nothing
+        >>> obs, reward, terminated, truncated, info = env_bonus.step(2)
+    """
+
+    def __init__(self, env):
+        """A wrapper that adds a distance-based shaping bonus for moving towards the goal.
+
+        Args:
+            env: The environment to apply the wrapper
+        """
+        super().__init__(env)
+        # the goal of the Empty environments sits at the bottom-right inner corner
+        self.goal_pos = (self.width - 2, self.height - 2)
+        self.scale = np.sqrt(self.width ** 2 + self.height ** 2)
+
+    def step(self, action):
+        """Steps through the environment with `action`."""
+
+        cur_dis = np.sqrt(
+            (self.goal_pos[0] - self.env.agent_pos[0]) ** 2 + (self.goal_pos[1] - self.env.agent_pos[1]) ** 2
+        )
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        tmp_dis = np.sqrt(
+            (self.goal_pos[0] - self.env.agent_pos[0]) ** 2 + (self.goal_pos[1] - self.env.agent_pos[1]) ** 2
+        )
+
+        move_bonus = (cur_dis - tmp_dis) / self.scale
+        reward += move_bonus
+
+        return obs, reward, terminated, truncated, info
+
+
+class OneHotObsWrapper(ObservationWrapper):
+    """
+    Wrapper to get a one-hot encoding of the grid cells
+    in the agent's image observation.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> from dizoo.minigrid.envs.minigrid_wrapper import OneHotObsWrapper
+        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
+        >>> env = OneHotObsWrapper(env)
+        >>> obs, _ = env.reset()
+        >>> # every grid cell becomes a float32 one-hot vector over object type,
+        >>> # color and state, so the last image dimension grows from 3 to
+        >>> # len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX) + 1
+        >>> obs["image"].shape
+        (7, 7, 21)
+    """
+
+    def __init__(self, env):
+        """A wrapper that makes the image observation a one-hot encoding of the grid cells in the agent's view.
+
+        Args:
+            env: The environment to apply the wrapper
+        """
+        super().__init__(env)
+
+        obs_shape = env.observation_space["image"].shape
+
+        # Number of bits per cell
+        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX) + 1
+
+        new_image_space = spaces.Box(low=0, high=1, shape=(obs_shape[0], obs_shape[1], num_bits), dtype="float32")
+        self.observation_space = spaces.Dict({**self.observation_space.spaces, "image": new_image_space})
+
+    def observation(self, obs):
+        img = obs["image"]
+        out = np.zeros(self.observation_space.spaces["image"].shape, dtype="float32")
+
+        for i in range(img.shape[0]):
+            for j in range(img.shape[1]):
+                type = img[i, j, 0]
+                color = img[i, j, 1]
+                state = img[i, j, 2]
+
+                out[i, j, type] = 1
+                out[i, j, len(OBJECT_TO_IDX) + color] = 1
+                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1
+
+        return {**obs, "image": out}
+
+
+class FlatObsWrapper(ObservationWrapper):
+    """
+    Flatten the image observation into a single one-dimensional array.
+
+    Unlike minigrid's built-in FlatObsWrapper, the mission string is dropped rather than
+    one-hot encoded and concatenated, so the output only contains the image data.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> from dizoo.minigrid.envs.minigrid_wrapper import FlatObsWrapper
+        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
+        >>> env_obs = FlatObsWrapper(env)
+        >>> obs, _ = env_obs.reset()
+        >>> obs.shape
+        (147,)
+    """
+
+    def __init__(self, env):
+        super().__init__(env)
+
+        imgSpace = env.observation_space.spaces["image"]
+        imgSize = reduce(operator.mul, imgSpace.shape, 1)
+
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(imgSize, ),
+            dtype="float32",
+        )
+
+    def observation(self, obs):
+        img = obs["image"]
+
+        img = img.flatten()
+
+        return img
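Note: a minimal usage sketch of how the new MiniGrid wrappers are meant to compose for the dreamer config above (same order as in `MiniGridEnv.reset`); the printed shape assumes `MiniGrid-Empty-8x8-v0` and the current minigrid constants (11 object types, 6 colors, 3 states), so treat the numbers as illustrative:

    import gymnasium as gym
    import minigrid  # noqa: importing minigrid registers its environments with gymnasium
    from minigrid.wrappers import FullyObsWrapper
    from dizoo.minigrid.envs.minigrid_wrapper import OneHotObsWrapper, MoveBonus, FlatObsWrapper

    env = gym.make('MiniGrid-Empty-8x8-v0')
    env = FullyObsWrapper(env)   # full 8x8 grid image instead of the egocentric partial view
    env = OneHotObsWrapper(env)  # (8, 8, 3) integer grid -> (8, 8, 21) float32 one-hot grid
    env = MoveBonus(env)         # reward += (distance_before - distance_after) / grid diagonal
    env = FlatObsWrapper(env)    # flatten the image to a 1D vector

    obs, _ = env.reset(seed=0)
    print(obs.shape)  # (1344,) = 8 * 8 * 21, matching state_size=1344 in minigrid_dreamer_config

This mirrors the wrapper chain built in `MiniGridEnv.reset` when `full_obs`, `onehot_obs`, `move_bonus` and `flat_obs` are all enabled, as they are in the new config.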