From ddb1a88ab44f8d2e3dac50123546b9a128175e26 Mon Sep 17 00:00:00 2001 From: anordin95 Date: Tue, 24 Dec 2024 12:19:11 -0800 Subject: [PATCH] refactor: observation & action for clarity --- README.md | 12 +-- generals/agents/expander_agent.py | 11 +- generals/agents/random_agent.py | 6 +- generals/core/action.py | 68 +++++++++++++ generals/core/game.py | 38 ++++--- generals/core/observation.py | 144 +++++++-------------------- generals/envs/gymnasium_generals.py | 6 +- generals/envs/pettingzoo_generals.py | 4 +- generals/remote/generalsio_client.py | 3 +- 9 files changed, 140 insertions(+), 152 deletions(-) create mode 100644 generals/core/action.py diff --git a/README.md b/README.md index 296663b..919f951 100644 --- a/README.md +++ b/README.md @@ -158,9 +158,7 @@ You can control your replays to your liking! Currently, we support these control ## 🌍 Environment ### 🔭 Observation -An observation for one agent is a dictionary `{"observation": observation, "action_mask": action_mask}`. - -The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N,M)`, or simple `int` constants: +An agent's observation contains a broad swath of information about its position in the game. Values are either `numpy` matrices with shape `(N,M)`, or `int` constants: | Key | Shape | Description | | -------------------- | --------- | ---------------------------------------------------------------------------- | | `armies` | `(N,M)` | Number of units in a visible cell regardless of the owner | @@ -179,9 +177,6 @@ The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N | `timestep` | — | Current timestep of the game | | `priority` | — | `1` if your move is evaluted first, `0` otherwise | -The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell -`[i, j]` in one of four directions: `0 (up)`, `1 (down)`, `2 (left)`, or `3 (right)`. 
- ### ⚡ Action Actions are lists of 5 values `[pass, cell_i, cell_j, direction, split]`, where - `pass` indicates whether you want to `1 (pass)` or `0 (play)`. @@ -190,6 +185,9 @@ Actions are lists of 5 values `[pass, cell_i, cell_j, direction, split]`, where - `direction` indicates whether you want to move `0 (up)`, `1 (down)`, `2 (left)`, or `3 (right)` - `split` indicates whether you want to `1 (split)` units and send only half, or `0 (no split)` where you send all units to the next cell +A convenience function `compute_valid_action_mask` is also provided for detailing the set of legal moves an agent can make based on its `observation`. The `valid_action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell +`[i, j]` in one of four directions: `0 (up)`, `1 (down)`, `2 (left)`, or `3 (right)`. + > [!TIP] > You can see how actions and observations look like by printing a sample form the environment: > ```python @@ -203,7 +201,7 @@ and gives `1` for winner and `-1` for loser, otherwise `0`. ```python def custom_reward_fn(observation, action, done, info): # Give agent a reward based on the number of cells they own - return observation["observation"]["owned_land_count"] + return observation["owned_land_count"] env = gym.make(..., reward_fn=custom_reward_fn) observations, info = env.reset() diff --git a/generals/agents/expander_agent.py b/generals/agents/expander_agent.py index ebd912b..cb500d3 100644 --- a/generals/agents/expander_agent.py +++ b/generals/agents/expander_agent.py @@ -1,7 +1,7 @@ import numpy as np +from generals.core.action import Action, compute_valid_action_mask from generals.core.config import Direction -from generals.core.game import Action from generals.core.observation import Observation from .agent import Agent @@ -16,16 +16,15 @@ def act(self, observation: Observation) -> Action: Heuristically selects a valid (expanding) action. Prioritizes capturing opponent and then neutral cells. 
""" - mask = observation["action_mask"] - observation = observation["observation"] + mask = compute_valid_action_mask(observation) valid_actions = np.argwhere(mask == 1) if len(valid_actions) == 0: # No valid actions return np.array([1, 0, 0, 0, 0]) - army = observation["armies"] - opponent = observation["opponent_cells"] - neutral = observation["neutral_cells"] + army = observation.armies + opponent = observation.opponent_cells + neutral = observation.neutral_cells # Find actions that capture opponent or neutral cells actions_capture_opponent = np.zeros(len(valid_actions)) diff --git a/generals/agents/random_agent.py b/generals/agents/random_agent.py index 7f80fe6..da16ca6 100644 --- a/generals/agents/random_agent.py +++ b/generals/agents/random_agent.py @@ -1,6 +1,6 @@ import numpy as np -from generals.core.game import Action +from generals.core.action import Action, compute_valid_action_mask from generals.core.observation import Observation from .agent import Agent @@ -23,8 +23,8 @@ def act(self, observation: Observation) -> Action: """ Randomly selects a valid action. """ - mask = observation["action_mask"] - observation = observation["observation"] + + mask = compute_valid_action_mask(observation) valid_actions = np.argwhere(mask == 1) if len(valid_actions) == 0: # No valid actions diff --git a/generals/core/action.py b/generals/core/action.py new file mode 100644 index 0000000..b67623b --- /dev/null +++ b/generals/core/action.py @@ -0,0 +1,68 @@ +from typing import TypeAlias + +import numpy as np + +from generals.core.config import DIRECTIONS + +from .observation import Observation + +""" +Action is intentionally a numpy array rather than a class for the sake of optimization. Granted, +this hasn't been performance tested, so take that decision with a grain of salt. + +The action format is an array with 5 entries: [pass, row, col, direction, split] +Args: + pass: boolean integer (0 or 1) indicating whether the agent should pass/skip this turn and do nothing. 
+ row: The row the agent should move from. In the closed-interval: [0, (grid_height - 1)]. + col: The column the agent should move from. In the closed-interval: [0, (grid_width - 1)]. + direction: An integer indicating which direction to move. 0 (up), 1 (down), 2 (left), 3 (right). + Note: the integer is effectively an index into the DIRECTIONS enum. + split: boolean integer (0 or 1) indicating whether to split the army when moving. +""" +Action: TypeAlias = np.ndarray + + +def compute_valid_action_mask(observation: Observation) -> np.ndarray: + """ + Return a mask of the valid actions for a given observation. + + A valid action is an action that originates from an agent's cell, has + at least 2 units and does not attempt to enter a mountain nor exit the grid. + + Returns: + np.ndarray: an NxMx4 array, where each channel is a boolean mask + of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid. + + I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j). + """ + height, width = observation.owned_cells.shape + + ownership_channel = observation.owned_cells + more_than_1_army = (observation.armies > 1) * ownership_channel + owned_cells_indices = np.argwhere(more_than_1_army) + valid_action_mask = np.zeros((height, width, 4), dtype=bool) + + if np.sum(ownership_channel) == 0: + return valid_action_mask + + for channel_index, direction in enumerate(DIRECTIONS): + destinations = owned_cells_indices + direction.value + + # check if destination is in grid bounds + in_first_boundary = np.all(destinations >= 0, axis=1) + in_height_boundary = destinations[:, 0] < height + in_width_boundary = destinations[:, 1] < width + destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary] + + # check if destination is road + passable_cells = 1 - observation.mountains + # assert that every value is either 0 or 1 in passable cells + assert np.all(np.isin(passable_cells, [0, 1])), f"{passable_cells}" + passable_cell_indices = 
passable_cells[destinations[:, 0], destinations[:, 1]] == 1 + action_destinations = destinations[passable_cell_indices] + + # get valid action mask for a given direction + valid_source_indices = action_destinations - direction.value + valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0 + + return valid_action_mask diff --git a/generals/core/game.py b/generals/core/game.py index 315d60e..900855c 100644 --- a/generals/core/game.py +++ b/generals/core/game.py @@ -3,13 +3,14 @@ import gymnasium as gym import numpy as np +from .action import Action from .channels import Channels from .config import DIRECTIONS from .grid import Grid from .observation import Observation # Type aliases -Action: TypeAlias = np.ndarray + Info: TypeAlias = dict[str, Any] @@ -41,26 +42,21 @@ def __init__(self, grid: Grid, agents: list[str]): grid_discrete = np.ones(self.grid_dims, dtype=int) * self.max_army_value self.observation_space = gym.spaces.Dict( { - "observation": gym.spaces.Dict( - { - "armies": gym.spaces.MultiDiscrete(grid_discrete), - "generals": grid_multi_binary, - "cities": grid_multi_binary, - "mountains": grid_multi_binary, - "neutral_cells": grid_multi_binary, - "owned_cells": grid_multi_binary, - "opponent_cells": grid_multi_binary, - "fog_cells": grid_multi_binary, - "structures_in_fog": grid_multi_binary, - "owned_land_count": gym.spaces.Discrete(self.max_land_value), - "owned_army_count": gym.spaces.Discrete(self.max_army_value), - "opponent_land_count": gym.spaces.Discrete(self.max_land_value), - "opponent_army_count": gym.spaces.Discrete(self.max_army_value), - "timestep": gym.spaces.Discrete(self.max_timestep), - "priority": gym.spaces.Discrete(2), - } - ), - "action_mask": gym.spaces.MultiBinary(self.grid_dims + (4,)), + "armies": gym.spaces.MultiDiscrete(grid_discrete), + "generals": grid_multi_binary, + "cities": grid_multi_binary, + "mountains": grid_multi_binary, + "neutral_cells": grid_multi_binary, + "owned_cells": 
grid_multi_binary, + "opponent_cells": grid_multi_binary, + "fog_cells": grid_multi_binary, + "structures_in_fog": grid_multi_binary, + "owned_land_count": gym.spaces.Discrete(self.max_land_value), + "owned_army_count": gym.spaces.Discrete(self.max_army_value), + "opponent_land_count": gym.spaces.Discrete(self.max_land_value), + "opponent_army_count": gym.spaces.Discrete(self.max_army_value), + "timestep": gym.spaces.Discrete(self.max_timestep), + "priority": gym.spaces.Discrete(2), } ) diff --git a/generals/core/observation.py b/generals/core/observation.py index b7aff2d..ca97e47 100644 --- a/generals/core/observation.py +++ b/generals/core/observation.py @@ -1,111 +1,39 @@ -import numpy as np - -from generals.core.config import DIRECTIONS - - -class Observation: - def __init__( - self, - armies: np.ndarray, - generals: np.ndarray, - cities: np.ndarray, - mountains: np.ndarray, - neutral_cells: np.ndarray, - owned_cells: np.ndarray, - opponent_cells: np.ndarray, - fog_cells: np.ndarray, - structures_in_fog: np.ndarray, - owned_land_count: int, - owned_army_count: int, - opponent_land_count: int, - opponent_army_count: int, - timestep: int, - priority: int = 0, - ): - self.armies = armies - self.generals = generals - self.cities = cities - self.mountains = mountains - self.neutral_cells = neutral_cells - self.owned_cells = owned_cells - self.opponent_cells = opponent_cells - self.fog_cells = fog_cells - self.structures_in_fog = structures_in_fog - self.owned_land_count = owned_land_count - self.owned_army_count = owned_army_count - self.opponent_land_count = opponent_land_count - self.opponent_army_count = opponent_army_count - self.timestep = timestep - self.priority = priority - # armies, generals, cities, mountains, empty, owner, fogged, structure in fog - - def action_mask(self) -> np.ndarray: - """ - Function to compute valid actions from a given ownership mask. 
- - Valid action is an action that originates from agent's cell with atleast 2 units - and does not bump into a mountain or fall out of the grid. - Returns: - np.ndarray: an NxNx4 array, where each channel is a boolean mask - of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid. - - I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j). - """ - height, width = self.owned_cells.shape +import dataclasses - ownership_channel = self.owned_cells - more_than_1_army = (self.armies > 1) * ownership_channel - owned_cells_indices = np.argwhere(more_than_1_army) - valid_action_mask = np.zeros((height, width, 4), dtype=bool) - - if np.sum(ownership_channel) == 0: - return valid_action_mask - - for channel_index, direction in enumerate(DIRECTIONS): - destinations = owned_cells_indices + direction.value - - # check if destination is in grid bounds - in_first_boundary = np.all(destinations >= 0, axis=1) - in_height_boundary = destinations[:, 0] < height - in_width_boundary = destinations[:, 1] < width - destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary] - - # check if destination is road - passable_cells = 1 - self.mountains - # assert that every value is either 0 or 1 in passable cells - assert np.all(np.isin(passable_cells, [0, 1])), f"{passable_cells}" - passable_cell_indices = passable_cells[destinations[:, 0], destinations[:, 1]] == 1 - action_destinations = destinations[passable_cell_indices] - - # get valid action mask for a given direction - valid_source_indices = action_destinations - direction.value - valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0 +import numpy as np - return valid_action_mask - def as_dict(self, with_mask=True): - _obs = { - "armies": self.armies, - "generals": self.generals, - "cities": self.cities, - "mountains": self.mountains, - "neutral_cells": self.neutral_cells, - "owned_cells": self.owned_cells, - "opponent_cells": 
self.opponent_cells, - "fog_cells": self.fog_cells, - "structures_in_fog": self.structures_in_fog, - "owned_land_count": self.owned_land_count, - "owned_army_count": self.owned_army_count, - "opponent_land_count": self.opponent_land_count, - "opponent_army_count": self.opponent_army_count, - "timestep": self.timestep, - "priority": self.priority, - } - if with_mask: - obs = { - "observation": _obs, - "action_mask": self.action_mask(), - } - else: - obs = _obs - return obs +@dataclasses.dataclass +class Observation(dict): + """ + We override some dictionary methods and subclass dict to allow the + Observation object to be accessible in dictionary-style format, + e.g. observation["armies"]. And to allow for providing a + listing of the keys/attributes. + + These steps are necessary because PettingZoo & Gymnasium expect + dictionary-like Observation objects, but we want the benefits of + knowing the dictionaries' members which a dataclass/class provides. + """ + + armies: np.ndarray + generals: np.ndarray + cities: np.ndarray + mountains: np.ndarray + neutral_cells: np.ndarray + owned_cells: np.ndarray + opponent_cells: np.ndarray + fog_cells: np.ndarray + structures_in_fog: np.ndarray + owned_land_count: int + owned_army_count: int + opponent_land_count: int + opponent_army_count: int + timestep: int + priority: int = 0 + + def __getitem__(self, attribute_name: str): + return getattr(self, attribute_name) + + def keys(self): + return dataclasses.asdict(self).keys() diff --git a/generals/envs/gymnasium_generals.py b/generals/envs/gymnasium_generals.py index 964bc70..5d1c847 100644 --- a/generals/envs/gymnasium_generals.py +++ b/generals/envs/gymnasium_generals.py @@ -96,20 +96,20 @@ def reset( self.observation_space = self.game.observation_space self.action_space = self.game.action_space - observation = self.game.agent_observation(self.agent_id).as_dict() + observation = self.game.agent_observation(self.agent_id) info: dict[str, Any] = {} return observation, info def 
step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]: # Get action of NPC - npc_observation = self.game.agent_observation(self.npc.id).as_dict() + npc_observation = self.game.agent_observation(self.npc.id) npc_action = self.npc.act(npc_observation) actions = {self.agent_id: action, self.npc.id: npc_action} observations, infos = self.game.step(actions) # From observations of all agents, pick only those relevant for the main agent - obs = observations[self.agent_id].as_dict() + obs = observations[self.agent_id] info = infos[self.agent_id] reward = self.reward_fn(obs, action, self.game.is_done(), info) terminated = self.game.is_done() diff --git a/generals/envs/pettingzoo_generals.py b/generals/envs/pettingzoo_generals.py index 6e2ce8d..628c7a4 100644 --- a/generals/envs/pettingzoo_generals.py +++ b/generals/envs/pettingzoo_generals.py @@ -105,7 +105,7 @@ def reset( elif hasattr(self, "replay"): del self.replay - observations = {agent: self.game.agent_observation(agent).as_dict() for agent in self.agents} + observations = {agent: self.game.agent_observation(agent) for agent in self.agents} infos: dict[str, Any] = {agent: {} for agent in self.agents} return observations, infos @@ -119,7 +119,7 @@ def step( dict[AgentID, Info], ]: observations, infos = self.game.step(actions) - observations = {agent: observation.as_dict() for agent, observation in observations.items()} + observations = {agent: observation for agent, observation in observations.items()} # You probably want to set your truncation based on self.game.time truncation = False if self.truncation is None else self.game.time >= self.truncation truncated = {agent: truncation for agent in self.agents} diff --git a/generals/remote/generalsio_client.py b/generals/remote/generalsio_client.py index f949edd..780ce40 100644 --- a/generals/remote/generalsio_client.py +++ b/generals/remote/generalsio_client.py @@ -125,8 +125,7 @@ def _generate_action(self, observation: Observation) 
-> tuple[int, int, int] | N Translate action from Agent to the server format. :param action: dictionary representing the action """ - obs = observation.as_dict() - action = self.agent.act(obs) + action = self.agent.act(observation) if not action["pass"]: source: np.ndarray = np.array(action["cell"]) direction = np.array(DIRECTIONS[action["direction"]].value)