Feature/actions as obs #291

Merged · 14 commits · May 27, 2024
Changes from 12 commits
6 changes: 4 additions & 2 deletions sheeprl/algos/a2c/a2c.py
@@ -236,7 +236,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sample an action given the observation received by the environment
# This calls the `forward` method of the PyTorch module, escaping from Fabric
# because we don't want this to be a synchronization point
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(
fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs
)
actions, _, values = player(torch_obs)
if is_continuous:
real_actions = torch.stack(actions, -1).cpu().numpy()
@@ -304,7 +306,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.inference_mode():
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
next_values = player.get_values(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
10 changes: 6 additions & 4 deletions sheeprl/algos/a2c/utils.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict
from typing import Any, Dict, Sequence

import numpy as np
import torch
@@ -13,8 +13,10 @@
AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Dict[str, Tensor]:
torch_obs = {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()}
def prepare_obs(
fabric: Fabric, obs: Dict[str, np.ndarray], *, mlp_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
torch_obs = {k: torch.from_numpy(obs[k].copy()).to(fabric.device).float().reshape(num_envs, -1) for k in mlp_keys}
return torch_obs


@@ -28,7 +30,7 @@ def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):

while not done:
# Convert observations to tensors
torch_obs = prepare_obs(fabric, obs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder)

# Act greedly through the environment
actions = agent.get_actions(torch_obs, greedy=True)
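As a rough usage sketch (not part of the diff), the reworked A2C `prepare_obs` now builds the tensor dict only from the keys listed in `mlp_keys`, so extra observation entries, such as the `action_stack` key introduced by this PR, stay out of the encoder input unless they are explicitly listed. The keys, shapes, and Fabric setup below are made up for illustration:

```python
import numpy as np
from lightning.fabric import Fabric

from sheeprl.algos.a2c.utils import prepare_obs

# Hypothetical setup: single-process CPU Fabric and a batch from 2 vectorized envs,
# with a 4-dimensional "state" vector plus an extra key left out of mlp_keys.
fabric = Fabric(accelerator="cpu", devices=1)
obs = {
    "state": np.random.rand(2, 4).astype(np.float32),
    "action_stack": np.zeros((2, 6), dtype=np.float32),
}

torch_obs = prepare_obs(fabric, obs, mlp_keys=["state"], num_envs=2)
print(list(torch_obs.keys()))    # ['state'] -- only the requested keys survive
print(torch_obs["state"].shape)  # torch.Size([2, 4])
```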
2 changes: 1 addition & 1 deletion sheeprl/algos/droq/droq.py
@@ -305,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
else:
with torch.inference_mode():
# Sample an action given the observation received by the environment
torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
actions = player(torch_obs)
actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac.py
@@ -256,7 +256,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
else:
# Sample an action given the observation received by the environment
with torch.inference_mode():
torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
actions = player(torch_obs)
actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
@@ -187,7 +187,7 @@ def player(
actions = envs.action_space.sample()
else:
# Sample an action given the observation received by the environment
torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
actions = actor(torch_obs)
actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
8 changes: 5 additions & 3 deletions sheeprl/algos/sac/utils.py
@@ -28,9 +28,11 @@
MODELS_TO_REGISTER = {"agent"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Tensor:
def prepare_obs(
fabric: Fabric, obs: Dict[str, np.ndarray], *, mlp_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Tensor:
with fabric.device:
torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in obs.keys()], dim=-1)
torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in mlp_keys], dim=-1)
return torch_obs.reshape(num_envs, -1)


@@ -43,7 +45,7 @@ def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
obs = env.reset(seed=cfg.seed)[0]
while not done:
# Act greedly through the environment
torch_obs = prepare_obs(fabric, obs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder)
action = actor.get_actions(torch_obs, greedy=True)

# Single environment step
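The SAC `prepare_obs` behaves the same way but returns a single tensor: the arrays for the keys in `mlp_keys` are concatenated along the last dimension and reshaped to `(num_envs, -1)`. A small sketch with made-up keys and shapes:

```python
import numpy as np
from lightning.fabric import Fabric

from sheeprl.algos.sac.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
obs = {
    "position": np.random.rand(2, 3).astype(np.float32),
    "velocity": np.random.rand(2, 3).astype(np.float32),
    "action_stack": np.zeros((2, 8), dtype=np.float32),  # ignored: not in mlp_keys
}

# Only the listed keys are concatenated into one flat vector per environment.
torch_obs = prepare_obs(fabric, obs, mlp_keys=["position", "velocity"], num_envs=2)
print(torch_obs.shape)  # torch.Size([2, 6])
```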
2 changes: 2 additions & 0 deletions sheeprl/configs/env/default.yaml
@@ -8,6 +8,8 @@ grayscale: False
clip_rewards: False
capture_video: True
frame_stack_dilation: 1
action_stack: -1
action_stack_dilation: 1
max_episode_steps: null
reward_as_observation: False
wrapper: ???
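The two new options work together: `action_stack` is the number of past actions appended to the observation (the default `-1` leaves the wrapper disabled, see the `cfg.env.action_stack > 0` check in `sheeprl/utils/env.py` below), while `action_stack_dilation` is the stride between the stacked actions, so the wrapper buffers `action_stack * action_stack_dilation` actions and exposes every `action_stack_dilation`-th one. A toy illustration of that bookkeeping, mirroring `_get_actions_stack` in the wrapper added below:

```python
from collections import deque

# Illustration only: with action_stack=3 and action_stack_dilation=2 the wrapper
# keeps the last 3 * 2 = 6 actions and exposes every 2nd one, oldest first.
num_stack, dilation = 3, 2
buffer = deque(maxlen=num_stack * dilation)
for t in range(10):  # pretend actions 0..9 were taken
    buffer.append(t)

stacked = list(buffer)[dilation - 1 :: dilation]
print(list(buffer))  # [4, 5, 6, 7, 8, 9]
print(stacked)       # [5, 7, 9] -> what ends up in obs["action_stack"]
```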
82 changes: 82 additions & 0 deletions sheeprl/envs/wrappers.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import time
from collections import deque
@@ -251,3 +253,83 @@ def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
if len(frame.shape) == 3 and frame.shape[-1] == 1:
frame = frame.repeat(3, axis=-1)
return frame


class ActionsAsObservationWrapper(gym.Wrapper):
def __init__(self, env: Env, num_stack: int, noop: float | int | List[int], dilation: int = 1):
super().__init__(env)
if num_stack < 1:
raise ValueError(
"The number of actions to the `action_stack` observation "
f"must be greater or equal than 1, got: {num_stack}"
)
if dilation < 1:
raise ValueError(f"The actions stack dilation argument must be greater than zero, got: {dilation}")
self._num_stack = num_stack
self._dilation = dilation
self._actions = deque(maxlen=num_stack * dilation)
self._is_continuous = isinstance(self.env.action_space, gym.spaces.Box)
self._is_multidiscrete = isinstance(self.env.action_space, gym.spaces.MultiDiscrete)
self.observation_space = copy.deepcopy(self.env.observation_space)
if self._is_continuous:
self._action_shape = self.env.action_space.shape[0]
low = np.resize(self.env.action_space.low, self._action_shape * num_stack)
high = np.resize(self.env.action_space.high, self._action_shape * num_stack)
elif self._is_multidiscrete:
low = 0
high = 1 # one-hot encoding
# one one-hot for each action
self._action_shape = sum(self.env.action_space.nvec)
else:
low = 0
high = 1 # one-hot encoding
self._action_shape = self.env.action_space.n
self.observation_space["action_stack"] = gym.spaces.Box(
low=low, high=high, shape=(self._action_shape * num_stack,), dtype=np.float32
)
if self._is_continuous:
if isinstance(noop, list):
raise ValueError(f"The noop actions must be a float for continuous action spaces, got: {noop}")
self.noop = np.full((self._action_shape,), noop, dtype=np.float32)
elif self._is_multidiscrete:
if not isinstance(noop, list):
raise ValueError(f"The noop actions must be a list for multi-discrete action spaces, got: {noop}")
noops = []
for act, n in zip(noop, self.env.action_space.nvec):
noops.append(np.zeros((n,), dtype=np.float32))
noops[-1][noop[act]] = 1.0
self.noop = np.concatenate(noops, axis=-1)
else:
if isinstance(noop, (list, float)):
raise ValueError(f"The noop actions must be an integer for discrete action spaces, got: {noop}")
self.noop = np.zeros((self._action_shape,), dtype=np.float32)
self.noop[noop] = 1.0

def step(self, action: Any) -> Tuple[Any | SupportsFloat | bool | Dict[str, Any]]:
if self._is_continuous:
self._actions.append(action)
elif self._is_multidiscrete:
one_hot_actions = []
for act, n in zip(action, self.env.action_space.nvec):
one_hot_actions.append(np.zeros((n,), dtype=np.float32))
one_hot_actions[-1][act] = 1.0
self._actions.append(np.concatenate(one_hot_actions, axis=-1))
else:
one_hot_action = np.zeros((self._action_shape,), dtype=np.float32)
one_hot_action[action] = 1.0
self._actions.append(one_hot_action)
obs, reward, done, truncated, info = super().step(action)
obs["action_stack"] = self._get_actions_stack()
return obs, reward, done, truncated, info

def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None) -> Tuple[Any | Dict[str, Any]]:
obs, info = super().reset(seed=seed, options=options)
self._actions.clear()
[self._actions.append(self.noop) for _ in range(self._num_stack * self._dilation)]
obs["action_stack"] = self._get_actions_stack()
return obs, info

def _get_actions_stack(self) -> np.ndarray:
actions_stack = list(self._actions)[self._dilation - 1 :: self._dilation]
actions = np.concatenate(actions_stack, axis=-1)
return actions.astype(np.float32)
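A hedged end-to-end sketch of the new wrapper, along the lines of the test added below: it wraps one of the repo's dummy environments (so the exact action-space size is whatever `DiscreteDummyEnv` defines) and checks that `obs["action_stack"]` has length `n_actions * num_stack`, one one-hot vector per stacked action, initialised to the no-op on reset:

```python
from sheeprl.envs.dummy import DiscreteDummyEnv
from sheeprl.envs.wrappers import ActionsAsObservationWrapper

env = DiscreteDummyEnv()
env = ActionsAsObservationWrapper(env, num_stack=4, noop=0, dilation=1)

obs, info = env.reset()
n_actions = env.action_space.n
# On reset the stack holds num_stack copies of the no-op one-hot vector.
assert obs["action_stack"].shape == (n_actions * 4,)

# After each step the newest action is one-hot encoded and pushed into the stack.
obs, reward, done, truncated, info = env.step(env.action_space.sample())
assert obs["action_stack"].shape == (n_actions * 4,)
```

For multi-discrete spaces each sub-action is one-hot encoded and the encodings are concatenated (with `noop` given as a list of per-sub-action no-op indices), while for continuous spaces the raw action vectors are stacked and `noop` must be a single float.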
4 changes: 4 additions & 0 deletions sheeprl/utils/env.py
@@ -9,6 +9,7 @@

from sheeprl.envs.wrappers import (
ActionRepeat,
ActionsAsObservationWrapper,
FrameStack,
GrayscaleRenderWrapper,
MaskVelocityWrapper,
@@ -207,6 +208,9 @@ def transform_obs(obs: Dict[str, Any]):
)
env = FrameStack(env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation)

if cfg.env.action_stack > 0 and "diambra" not in cfg.env.wrapper._target_:
env = ActionsAsObservationWrapper(env, cfg.env.action_stack, cfg.env.action_stack_dilation)

if cfg.env.reward_as_observation:
env = RewardAsObservationWrapper(env)

32 changes: 31 additions & 1 deletion tests/test_envs/test_wrappers.py
@@ -1,10 +1,40 @@
import gymnasium as gym
import pytest

from sheeprl.envs.wrappers import MaskVelocityWrapper
from sheeprl.envs.dummy import ContinuousDummyEnv, DiscreteDummyEnv, MultiDiscreteDummyEnv
from sheeprl.envs.wrappers import ActionsAsObservationWrapper, MaskVelocityWrapper

ENVIRONMENTS = {
"discrete_dummy": DiscreteDummyEnv,
"multidiscrete_dummy": MultiDiscreteDummyEnv,
"continuous_dummy": ContinuousDummyEnv,
}


def test_mask_velocities_fail():
with pytest.raises(NotImplementedError):
env = gym.make("CarRacing-v2")
env = MaskVelocityWrapper(env)


@pytest.mark.parametrize("num_stack", [1, 4, 8])
@pytest.mark.parametrize("dilation", [1, 2, 4])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observations_wrapper(env_id: str, num_stack, dilation):
env = ENVIRONMENTS[env_id]()
if isinstance(env.action_space, gym.spaces.MultiDiscrete):
noop = [0, 0]
else:
noop = 0
env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=dilation)

o = env.reset()[0]
assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape)
for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape):
assert d1 == d2

for _ in range(64):
o = env.step(env.action_space.sample())[0]
assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape)
for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape):
assert d1 == d2