Skip to content

Commit

Permalink
feat: Add new, improved wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Nov 22, 2024
1 parent 8b9676e commit 96b27b4
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 117 deletions.
8 changes: 4 additions & 4 deletions generals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def _register_gym_generals_envs():
)

register(
id="gym-generals-normalized-v0",
entry_point="generals.envs.initializers:gyms_generals_normalized_v0",
id="gym-generals-image-v0",
entry_point="generals.envs.initializers:gym_image_observations",
)

register(
id="gym-generals-image-v0",
entry_point="generals.envs.initializers:gym_image_observations",
id="gym-generals-rllib-v0",
entry_point="generals.envs.initializers:gym_rllib",
)


Expand Down
12 changes: 9 additions & 3 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, grid: Grid, agents: list[str]):
self.increment_rate = 50

# Limits
self.max_army_value = 10_000
self.max_army_value = 100_000
self.max_land_value = np.prod(self.grid_dims)
self.max_timestep = 100_000

Expand All @@ -52,9 +52,9 @@ def __init__(self, grid: Grid, agents: list[str]):
"opponent_cells": grid_multi_binary,
"fog_cells": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": gym.spaces.Discrete(self.max_army_value),
"owned_land_count": gym.spaces.Discrete(self.max_land_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_land_value),
"opponent_army_count": gym.spaces.Discrete(self.max_army_value),
"timestep": gym.spaces.Discrete(self.max_timestep),
"priority": gym.spaces.Discrete(2),
Expand Down Expand Up @@ -119,6 +119,12 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict
sj + DIRECTIONS[direction].value[1],
) # destination indices

# Skip if the destination cell is not passable or out of bounds
if di < 0 or di >= self.grid_dims[0] or dj < 0 or dj >= self.grid_dims[1]:
continue
if self.channels.passable[di, dj] == 0:
continue

# Figure out the target square owner and army size
target_square_army = self.channels.armies[di, dj]
target_square_owner_idx = np.argmax(
Expand Down
136 changes: 34 additions & 102 deletions generals/envs/gymnasium_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,121 +2,53 @@
import numpy as np


class NormalizedObservationWrapper(gym.ObservationWrapper):
class RemoveActionMaskWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
grid_multi_binary = gym.spaces.MultiBinary(self.game.grid_dims)
unit_box = gym.spaces.Box(low=0, high=1, dtype=np.float32)
self.observation_space = gym.spaces.Dict(
{
"observation": gym.spaces.Dict(
{
"army": gym.spaces.Box(low=0, high=1, shape=self.game.grid_dims, dtype=np.float32),
"general": grid_multi_binary,
"city": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": unit_box,
"owned_army_count": unit_box,
"opponent_land_count": unit_box,
"opponent_army_count": unit_box,
"is_winner": gym.spaces.Discrete(2),
"timestep": unit_box,
}
),
"action_mask": gym.spaces.MultiBinary(self.game.grid_dims + (4,)),
}
)
self.observation_space = env.observation_space["observation"]

def observation(self, observation):
game = self.game
_observation = observation["observation"] if "observation" in observation else observation
_observation["army"] = np.array(_observation["army"] / game.max_army_value, dtype=np.float32)
_observation["timestep"] = np.array([_observation["timestep"] / game.max_timestep], dtype=np.float32)
_observation["owned_land_count"] = np.array(
[_observation["owned_land_count"] / game.max_land_value], dtype=np.float32
)
_observation["opponent_land_count"] = np.array(
[_observation["opponent_land_count"] / game.max_land_value],
dtype=np.float32,
)
_observation["owned_army_count"] = np.array(
[_observation["owned_army_count"] / game.max_army_value], dtype=np.float32
)
_observation["opponent_army_count"] = np.array(
[_observation["opponent_army_count"] / game.max_army_value],
dtype=np.float32,
)
observation["observation"] = _observation
return observation
return _observation


class RemoveActionMaskWrapper(gym.ObservationWrapper):
class ObservationAsImageWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
grid_multi_binary = gym.spaces.MultiBinary(self.grid_dims)
grid_discrete = np.ones(self.grid_dims, dtype=int) * self.max_army_value
n_obs_keys = len(self.observation_space["observation"].items())
self.observation_space = gym.spaces.Dict(
{
"armies": gym.spaces.MultiDiscrete(grid_discrete),
"generals": grid_multi_binary,
"cities": grid_multi_binary,
"mountains": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"fog_cells": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": gym.spaces.Discrete(self.max_army_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_army_value),
"opponent_army_count": gym.spaces.Discrete(self.max_army_value),
"timestep": gym.spaces.Discrete(self.max_timestep),
"priority": gym.spaces.Discrete(2),
"observation": gym.spaces.Box(
low=0, high=1, shape=self.game.grid_dims + (n_obs_keys,), dtype=np.float32
),
"action_mask": gym.spaces.MultiBinary(self.game.grid_dims + (4,)),
}
)

def observation(self, observation):
_observation = observation["observation"] if "observation" in observation else observation
return _observation


class ObservationAsImageWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=1, shape=self.game.grid_dims + (14,), dtype=np.float32)

def observation(self, observation):
_observation = observation["observation"] if "observation" in observation else observation
# broadcast owned_land_count and other unit_boxes to the shape of the grid
_owned_land_count = np.broadcast_to(_observation["owned_land_count"], self.game.grid_dims)
_owned_army_count = np.broadcast_to(_observation["owned_army_count"], self.game.grid_dims)
_opponent_land_count = np.broadcast_to(_observation["opponent_land_count"], self.game.grid_dims)
_opponent_army_count = np.broadcast_to(_observation["opponent_army_count"], self.game.grid_dims)
_is_winner = np.broadcast_to(_observation["is_winner"], self.game.grid_dims)
_timestep = np.broadcast_to(_observation["timestep"], self.game.grid_dims)
_observation = np.stack(
[
_observation["army"],
_observation["general"],
_observation["city"],
_observation["owned_cells"],
_observation["opponent_cells"],
_observation["neutral_cells"],
_observation["visible_cells"],
_observation["structures_in_fog"],
_owned_land_count,
_owned_army_count,
_opponent_land_count,
_opponent_army_count,
_is_winner,
_timestep,
],
dtype=np.float32,
axis=-1,
game = self.game
_obs = observation["observation"] if "observation" in observation else observation
_obs = (
np.stack(
[
_obs["armies"] / game.max_army_value,
_obs["generals"],
_obs["cities"],
_obs["mountains"],
_obs["neutral_cells"],
_obs["owned_cells"],
_obs["opponent_cells"],
_obs["fog_cells"],
_obs["structures_in_fog"],
np.ones(game.grid_dims) * _obs["owned_land_count"] / game.max_land_value,
np.ones(game.grid_dims) * _obs["owned_army_count"] / game.max_army_value,
np.ones(game.grid_dims) * _obs["opponent_land_count"] / game.max_land_value,
np.ones(game.grid_dims) * _obs["opponent_army_count"] / game.max_army_value,
np.ones(game.grid_dims) * _obs["timestep"] / game.max_timestep,
np.ones(game.grid_dims) * _obs["priority"],
]
)
.astype(np.float32)
.transpose(1, 2, 0)
)
_observation = np.moveaxis(_observation, -1, 0)
return _observation
return _obs
18 changes: 10 additions & 8 deletions generals/envs/initializers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from generals import GridFactory
from generals.agents import Agent
from generals.envs.gymnasium_generals import GymnasiumGenerals, RewardFn
from generals.envs.gymnasium_wrappers import NormalizedObservationWrapper, ObservationAsImageWrapper
from generals.envs.gymnasium_wrappers import ObservationAsImageWrapper, RemoveActionMaskWrapper

"""
Here we can define environment initialization functions that
Expand All @@ -15,7 +15,7 @@
"""


def gym_generals_normalized_v0(
def gym_image_observations(
grid_factory: GridFactory | None = None,
npc: Agent | None = None,
agent: Agent | None = None,
Expand All @@ -24,7 +24,7 @@ def gym_generals_normalized_v0(
):
"""
Example of a Gymnasium environment initializer that creates
an environment that returns normalized observations.
an environment that returns image observations.
"""
_env = GymnasiumGenerals(
grid_factory=grid_factory,
Expand All @@ -33,10 +33,11 @@ def gym_generals_normalized_v0(
render_mode=render_mode,
reward_fn=reward_fn,
)
env = NormalizedObservationWrapper(_env)
env = ObservationAsImageWrapper(_env)
return env

def gym_image_observations(

def gym_rllib(
grid_factory: GridFactory | None = None,
npc: Agent | None = None,
agent: Agent | None = None,
Expand All @@ -47,12 +48,13 @@ def gym_image_observations(
Example of a Gymnasium environment initializer that creates
an environment that returns image observations.
"""
_env = GymnasiumGenerals(
env = GymnasiumGenerals(
grid_factory=grid_factory,
npc=npc,
agent=agent,
render_mode=render_mode,
reward_fn=reward_fn,
)
env = ObservationAsImageWrapper(_env)
return env
image_env = ObservationAsImageWrapper(env)
no_action_env = RemoveActionMaskWrapper(image_env)
return no_action_env

0 comments on commit 96b27b4

Please sign in to comment.