From b58e4fc8f6e677086a0fb7ef2714df791feb9959 Mon Sep 17 00:00:00 2001 From: Matej Straka Date: Wed, 23 Oct 2024 20:37:22 +0200 Subject: [PATCH] refactor: Make Observation a standalone object .. and make it compatible with observations from simulator and generalsio --- generals/agents/expander_agent.py | 2 +- generals/core/channels.py | 7 + generals/core/game.py | 170 ++++++++++------------- generals/core/observation.py | 193 +++++++++++++++++++++++++++ generals/envs/gymnasium_generals.py | 15 ++- generals/gui/rendering.py | 12 +- generals/remote/generalsio_client.py | 47 +++++-- 7 files changed, 329 insertions(+), 117 deletions(-) create mode 100644 generals/core/observation.py diff --git a/generals/agents/expander_agent.py b/generals/agents/expander_agent.py index ca3550b..b7cfbec 100644 --- a/generals/agents/expander_agent.py +++ b/generals/agents/expander_agent.py @@ -26,7 +26,7 @@ def act(self, observation: Observation) -> Action: "split": 0, } - army = observation["army"] + army = observation["armies"] opponent = observation["opponent_cells"] neutral = observation["neutral_cells"] diff --git a/generals/core/channels.py b/generals/core/channels.py index b56dfbd..ce720d4 100644 --- a/generals/core/channels.py +++ b/generals/core/channels.py @@ -39,6 +39,13 @@ def get_visibility(self, agent_id: str) -> np.ndarray: channel = self._ownership[agent_id] return maximum_filter(channel, size=3) + @staticmethod + def channel_to_indices(channel: np.ndarray) -> np.ndarray: + """ + Returns a list of indices of cells with non-zero values from specified a channel. + """ + return np.argwhere(channel != 0) + @property def ownership(self) -> dict[str, np.ndarray]: return self._ownership diff --git a/generals/core/game.py b/generals/core/game.py index 12e5164..0330945 100644 --- a/generals/core/game.py +++ b/generals/core/game.py @@ -40,9 +40,9 @@ def __init__(self, grid: Grid, agents: list[str]): { "observation": gym.spaces.Dict( { - "army": gym.spaces.MultiDiscrete(grid_discrete), - "general": grid_multi_binary, - "city": grid_multi_binary, + "armies": gym.spaces.MultiDiscrete(grid_discrete), + "generals": grid_multi_binary, + "cities": grid_multi_binary, "owned_cells": grid_multi_binary, "opponent_cells": grid_multi_binary, "neutral_cells": grid_multi_binary, @@ -52,7 +52,6 @@ def __init__(self, grid: Grid, agents: list[str]): "owned_army_count": gym.spaces.Discrete(self.max_army_value), "opponent_land_count": gym.spaces.Discrete(self.max_army_value), "opponent_army_count": gym.spaces.Discrete(self.max_army_value), - "is_winner": gym.spaces.Discrete(2), "timestep": gym.spaces.Discrete(self.max_timestep), } ), @@ -69,51 +68,46 @@ def __init__(self, grid: Grid, agents: list[str]): } ) - def action_mask(self, agent: str) -> np.ndarray: - """ - Function to compute valid actions from a given ownership mask. - - Valid action is an action that originates from agent's cell with atleast 2 units - and does not bump into a mountain or fall out of the grid. - Returns: - np.ndarray: an NxNx4 array, where each channel is a boolean mask - of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid. - - I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j). - """ - - ownership_channel = self.channels.ownership[agent] - more_than_1_army = (self.channels.army > 1) * ownership_channel - owned_cells_indices = self.channel_to_indices(more_than_1_army) - valid_action_mask = np.zeros((self.grid_dims[0], self.grid_dims[1], 4), dtype=bool) - - if self.is_done() and not self.agent_won(agent): # if you lost, return all zeros - return valid_action_mask + # def action_mask(self, agent: str) -> np.ndarray: + # """ + # Function to compute valid actions from a given ownership mask. + # + # Valid action is an action that originates from agent's cell with atleast 2 units + # and does not bump into a mountain or fall out of the grid. + # Returns: + # np.ndarray: an NxNx4 array, where each channel is a boolean mask + # of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid. + # + # I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j). + # """ + # + # ownership_channel = self.channels.ownership[agent] + # more_than_1_army = (self.channels.army > 1) * ownership_channel + # owned_cells_indices = self.channel_to_indices(more_than_1_army) + # valid_action_mask = np.zeros((self.grid_dims[0], self.grid_dims[1], 4), dtype=bool) + # + # if self.is_done() and not self.agent_won(agent): # if you lost, return all zeros + # return valid_action_mask + # + # for channel_index, direction in enumerate(DIRECTIONS): + # destinations = owned_cells_indices + direction.value + # + # # check if destination is in grid bounds + # in_first_boundary = np.all(destinations >= 0, axis=1) + # in_height_boundary = destinations[:, 0] < self.grid_dims[0] + # in_width_boundary = destinations[:, 1] < self.grid_dims[1] + # destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary] + # + # # check if destination is road + # passable_cell_indices = self.channels.passable[destinations[:, 0], destinations[:, 1]] == 1 + # action_destinations = destinations[passable_cell_indices] + # + # # get valid action mask for a given direction + # valid_source_indices = action_destinations - direction.value + # valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0 + # # assert False + # return valid_action_mask - for channel_index, direction in enumerate(DIRECTIONS): - destinations = owned_cells_indices + direction.value - - # check if destination is in grid bounds - in_first_boundary = np.all(destinations >= 0, axis=1) - in_height_boundary = destinations[:, 0] < self.grid_dims[0] - in_width_boundary = destinations[:, 1] < self.grid_dims[1] - destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary] - - # check if destination is road - passable_cell_indices = self.channels.passable[destinations[:, 0], destinations[:, 1]] == 1 - action_destinations = destinations[passable_cell_indices] - - # get valid action mask for a given direction - valid_source_indices = action_destinations - direction.value - valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0 - # assert False - return valid_action_mask - - def channel_to_indices(self, channel: np.ndarray) -> np.ndarray: - """ - Returns a list of indices of cells with non-zero values from specified a channel. - """ - return np.argwhere(channel != 0) def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, Any]]: """ @@ -135,14 +129,6 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict # Skip if agent wants to pass the turn if pass_turn == 1: continue - # Skip if the move is invalid - if self.action_mask(agent)[i, j, direction] == 0: - warnings.warn( - f"The submitted move byt agent {agent} does not take effect.\ - Probably because you submitted an invalid move.", - UserWarning, - ) - continue if split_army == 1: # Agent wants to split the army army_to_move = self.channels.army[i, j] // 2 else: # Leave just one army in the source cell @@ -199,16 +185,6 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict else: self._global_game_update() - observations = self.get_all_observations() - infos: dict[str, Any] = {agent: {} for agent in self.agents} - return observations, infos - - def get_all_observations(self) -> dict[str, Observation]: - """ - Returns observations for all agents. - """ - return {agent: self.agent_observation(agent) for agent in self.agents} - def _global_game_update(self) -> None: """ Update game state globally. @@ -250,37 +226,37 @@ def get_infos(self) -> dict[str, Info]: "is_winner": self.agent_won(agent), } return players_stats - - def agent_observation(self, agent: str) -> Observation: - """ - Returns an observation for a given agent. - """ - info = self.get_infos() - opponent = self.agents[0] if agent == self.agents[1] else self.agents[1] - visible = self.channels.get_visibility(agent) - invisible = 1 - visible - _observation = { - "army": self.channels.army.astype(int) * visible, - "general": self.channels.general * visible, - "city": self.channels.city * visible, - "owned_cells": self.channels.ownership[agent] * visible, - "opponent_cells": self.channels.ownership[opponent] * visible, - "neutral_cells": self.channels.ownership_neutral * visible, - "visible_cells": visible, - "structures_in_fog": invisible * (self.channels.mountain + self.channels.city), - "owned_land_count": info[agent]["land"], - "owned_army_count": info[agent]["army"], - "opponent_land_count": info[opponent]["land"], - "opponent_army_count": info[opponent]["army"], - "is_winner": int(info[agent]["is_winner"]), - "timestep": self.time, - } - observation: Observation = { - "observation": _observation, - "action_mask": self.action_mask(agent), - } - - return observation + # + # def agent_observation(self, agent: str) -> Observation: + # """ + # Returns an observation for a given agent. + # """ + # info = self.get_infos() + # opponent = self.agents[0] if agent == self.agents[1] else self.agents[1] + # visible = self.channels.get_visibility(agent) + # invisible = 1 - visible + # _observation = { + # "army": self.channels.army.astype(int) * visible, + # "general": self.channels.general * visible, + # "city": self.channels.city * visible, + # "owned_cells": self.channels.ownership[agent] * visible, + # "opponent_cells": self.channels.ownership[opponent] * visible, + # "neutral_cells": self.channels.ownership_neutral * visible, + # "visible_cells": visible, + # "structures_in_fog": invisible * (self.channels.mountain + self.channels.city), + # "owned_land_count": info[agent]["land"], + # "owned_army_count": info[agent]["army"], + # "opponent_land_count": info[opponent]["land"], + # "opponent_army_count": info[opponent]["army"], + # "is_winner": int(info[agent]["is_winner"]), + # "timestep": self.time, + # } + # observation: Observation = { + # "observation": _observation, + # "action_mask": self.action_mask(agent), + # } + # + # return observation def agent_won(self, agent: str) -> bool: """ diff --git a/generals/core/observation.py b/generals/core/observation.py new file mode 100644 index 0000000..02b95f9 --- /dev/null +++ b/generals/core/observation.py @@ -0,0 +1,193 @@ +import numpy as np +from scipy.ndimage import maximum_filter # type: ignore + +from generals.core.game import DIRECTIONS, Game +from generals.remote.generalsio_client import GeneralsIOState + + +def observation_from_simulator(game: Game, agent_id: str) -> "Observation": + scores = {} + for agent in game.agents: + army_size = np.sum(game.channels.army * game.channels.ownership[agent]).astype(int) + land_size = np.sum(game.channels.ownership[agent]).astype(int) + scores[agent] = { + "army": army_size, + "land": land_size, + } + opponent = game.agents[0] if agent_id == game.agents[1] else game.agents[1] + visible = game.channels.get_visibility(agent_id) + invisible = 1 - visible + army = game.channels.army.astype(int) * visible + generals = game.channels.general * visible + city = game.channels.city * visible + owned_cells = game.channels.ownership[agent_id] * visible + opponent_cells = game.channels.ownership[opponent] * visible + neutral_cells = game.channels.ownership_neutral * visible + visible_cells = visible + structures_in_fog = invisible * (game.channels.mountain + game.channels.city) + owned_land_count = scores[agent_id]["land"] + owned_army_count = scores[agent_id]["army"] + opponent_land_count = scores[opponent]["land"] + opponent_army_count = scores[opponent]["army"] + timestep = game.time + + return Observation( + army=army, + generals=generals, + city=city, + owned_cells=owned_cells, + opponent_cells=opponent_cells, + neutral_cells=neutral_cells, + visible_cells=visible_cells, + structures_in_fog=structures_in_fog, + owned_land_count=owned_land_count, + owned_army_count=owned_army_count, + opponent_land_count=opponent_land_count, + opponent_army_count=opponent_army_count, + timestep=timestep, + ) + + +def observation_from_generalsio_state(state: GeneralsIOState) -> "Observation": + width, height = state.map[0], state.map[1] + size = height * width + + armies = np.array(state.map[2 : 2 + size]).reshape((height, width)) + terrain = np.array(state.map[2 + size : 2 + 2 * size]).reshape((height, width)) + cities = np.zeros((height, width)) + for city in state.cities: + cities[city // width, city % width] = 1 + + generals = np.zeros((height, width)) + for general in state.generals: + if general != -1: + generals[general // width, general % width] = 1 + + army = armies + owned_cells = np.where(terrain == state.player_index, 1, 0) + opponent_cells = np.where(terrain == state.opponent_index, 1, 0) + neutral_cells = np.where(terrain == -1, 1, 0) + visible_cells = maximum_filter(np.where(terrain == state.player_index, 1, 0), size=3) + structures_in_fog = np.where(terrain == -4, 1, 0) + owned_land_count = state.scores[state.player_index]["tiles"] + owned_army_count = state.scores[state.player_index]["total"] + opponent_land_count = state.scores[state.opponent_index]["tiles"] + opponent_army_count = state.scores[state.opponent_index]["total"] + timestep = state.turn + + return Observation( + army=army, + generals=generals, + city=cities, + owned_cells=owned_cells, + opponent_cells=opponent_cells, + neutral_cells=neutral_cells, + visible_cells=visible_cells, + structures_in_fog=structures_in_fog, + owned_land_count=owned_land_count, + owned_army_count=owned_army_count, + opponent_land_count=opponent_land_count, + opponent_army_count=opponent_army_count, + timestep=timestep, + ) + + +class Observation: + def __init__( + self, + army: np.ndarray, + generals: np.ndarray, + city: np.ndarray, + owned_cells: np.ndarray, + opponent_cells: np.ndarray, + neutral_cells: np.ndarray, + visible_cells: np.ndarray, + structures_in_fog: np.ndarray, + owned_land_count: int, + owned_army_count: int, + opponent_land_count: int, + opponent_army_count: int, + timestep: int, + ): + self.army = army + self.generals = generals + self.city = city + self.owned_cells = owned_cells + self.opponent_cells = opponent_cells + self.neutral_cells = neutral_cells + self.visible_cells = visible_cells + self.structures_in_fog = structures_in_fog + self.owned_land_count = owned_land_count + self.owned_army_count = owned_army_count + self.opponent_land_count = opponent_land_count + self.opponent_army_count = opponent_army_count + self.timestep = timestep + + def action_mask(self) -> np.ndarray: + """ + Function to compute valid actions from a given ownership mask. + + Valid action is an action that originates from agent's cell with atleast 2 units + and does not bump into a mountain or fall out of the grid. + Returns: + np.ndarray: an NxNx4 array, where each channel is a boolean mask + of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid. + + I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j). + """ + height, width = self.owned_cells.shape + + ownership_channel = self.owned_cells + more_than_1_army = (self.army > 1) * ownership_channel + owned_cells_indices = np.argwhere(more_than_1_army) + valid_action_mask = np.zeros((height, width, 4), dtype=bool) + + if np.sum(ownership_channel) == 0: + return valid_action_mask + + for channel_index, direction in enumerate(DIRECTIONS): + destinations = owned_cells_indices + direction.value + + # check if destination is in grid bounds + in_first_boundary = np.all(destinations >= 0, axis=1) + in_height_boundary = destinations[:, 0] < height + in_width_boundary = destinations[:, 1] < width + destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary] + + # check if destination is road + passable_cells = self.neutral_cells + self.owned_cells + self.opponent_cells + self.city + # assert that every value is either 0 or 1 in passable cells + assert np.all(np.isin(passable_cells, [0, 1])) + passable_cell_indices = passable_cells[destinations[:, 0], destinations[:, 1]] == 1 + action_destinations = destinations[passable_cell_indices] + + # get valid action mask for a given direction + valid_source_indices = action_destinations - direction.value + valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0 + + return valid_action_mask + + def as_dict(self, with_mask=True): + _obs = { + "armies": self.army, + "generals": self.generals, + "cities": self.city, + "owned_cells": self.owned_cells, + "opponent_cells": self.opponent_cells, + "neutral_cells": self.neutral_cells, + "visible_cells": self.visible_cells, + "structures_in_fog": self.structures_in_fog, + "owned_land_count": self.owned_land_count, + "owned_army_count": self.owned_army_count, + "opponent_land_count": self.opponent_land_count, + "opponent_army_count": self.opponent_army_count, + "timestep": self.timestep, + } + if with_mask: + obs = { + "observation": _obs, + "action_mask": self.action_mask(), + } + else: + obs = _obs + return obs diff --git a/generals/envs/gymnasium_generals.py b/generals/envs/gymnasium_generals.py index febaf91..f10642c 100644 --- a/generals/envs/gymnasium_generals.py +++ b/generals/envs/gymnasium_generals.py @@ -5,8 +5,9 @@ from generals.agents import Agent, AgentFactory from generals.core.config import Reward, RewardFn -from generals.core.game import Action, Game, Info, Observation +from generals.core.game import Action, Game, Info from generals.core.grid import GridFactory +from generals.core.observation import Observation, observation_from_simulator from generals.core.replay import Replay from generals.gui import GUI from generals.gui.properties import GuiMode @@ -90,16 +91,22 @@ def reset( self.observation_space = self.game.observation_space self.action_space = self.game.action_space - observation = self.game.agent_observation(self.agent_id) + observation = observation_from_simulator(self.game, self.agent_id).as_dict() info: dict[str, Any] = {} return observation, info def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]: # Get action of NPC - npc_action = self.npc.act(self.game.agent_observation(self.npc.id)) + npc_ovservation = observation_from_simulator(self.game, self.npc.id).as_dict() + npc_action = self.npc.act(npc_ovservation) actions = {self.agent_id: action, self.npc.id: npc_action} - observations, infos = self.game.step(actions) + self.game.step(actions) + observations = { + agent_id: observation_from_simulator(self.game, agent_id).as_dict() for agent_id in self.agent_ids + } + infos = {agent_id: {} for agent_id in self.agent_ids} + # From observations of all agents, pick only those relevant for the main agent obs = observations[self.agent_id] info = infos[self.agent_id] diff --git a/generals/gui/rendering.py b/generals/gui/rendering.py index 96a4ce8..8eba516 100644 --- a/generals/gui/rendering.py +++ b/generals/gui/rendering.py @@ -225,7 +225,7 @@ def render_grid(self): # Draw nonzero army counts on visible squares visible_army = self.game.channels.army * visible_map - visible_army_indices = self.game.channel_to_indices(visible_army) + visible_army_indices = self.channel_to_indices(visible_army) for i, j in visible_army_indices: self.render_cell_text( self.tiles[i][j], @@ -240,12 +240,18 @@ def render_grid(self): self.game_area.blit(self.tiles[i][j], (j * square_size, i * square_size)) self.screen.blit(self.game_area, (0, 0)) + def channel_to_indices(self, channel: np.ndarray) -> np.ndarray: + """ + Returns a list of indices of cells with non-zero values from specified a channel. + """ + return np.argwhere(channel != 0) + def draw_channel(self, channel: np.ndarray, color: Color): """ Draw background and borders (left and top) for grid tiles of a given channel """ square_size = Dimension.SQUARE_SIZE.value - for i, j in self.game.channel_to_indices(channel): + for i, j in self.channel_to_indices(channel): self.tiles[i][j].fill(color) pygame.draw.line(self.tiles[i][j], BLACK, (0, 0), (0, square_size), 1) pygame.draw.line(self.tiles[i][j], BLACK, (0, 0), (square_size, 0), 1) @@ -254,5 +260,5 @@ def draw_images(self, channel: np.ndarray, image: pygame.Surface): """ Draw images on grid tiles of a given channel """ - for i, j in self.game.channel_to_indices(channel): + for i, j in self.channel_to_indices(channel): self.tiles[i][j].blit(image, (3, 2)) diff --git a/generals/remote/generalsio_client.py b/generals/remote/generalsio_client.py index b8fcb11..5b620f9 100644 --- a/generals/remote/generalsio_client.py +++ b/generals/remote/generalsio_client.py @@ -1,4 +1,5 @@ import numpy as np +from scipy.ndimage import maximum_filter # type: ignore from socketio import SimpleClient # type: ignore from generals.agents.agent import Agent @@ -43,12 +44,13 @@ def apply_diff(old: list[int], diff: list[int]) -> list[int]: i += 1 return new + test_old_1 = [0, 0] test_diff_1 = [1, 1, 3] -desired = [0,3] +desired = [0, 3] assert apply_diff(test_old_1, test_diff_1) == desired -test_old_2 = [0,0] -test_diff_2 = [0,1,2,1] +test_old_2 = [0, 0] +test_diff_2 = [0, 1, 2, 1] desired = [2, 0] assert apply_diff(test_old_2, test_diff_2) == desired print("All tests passed") @@ -59,7 +61,7 @@ def __init__(self, data: dict): self.replay_id = data["replay_id"] self.usernames = data["usernames"] self.player_index = data["playerIndex"] - self.opponent_index = 1 - self.player_index # works only for 1v1 + self.opponent_index = 1 - self.player_index # works only for 1v1 self.n_players = len(self.usernames) @@ -75,20 +77,41 @@ def update(self, data: dict) -> None: if "stars" in data: self.stars = data["stars"] - def agent_observation(self) -> Observation: width, height = self.map[0], self.map[1] size = height * width armies = np.array(self.map[2 : 2 + size]).reshape((height, width)) terrain = np.array(self.map[2 + size : 2 + 2 * size]).reshape((height, width)) - - # make 2D binary map of owned cells. These are the ones that have self.player_index value in terrain - army = armies - owned_cells = np.where(terrain == self.player_index, 1, 0) - opponent_cells = np.where(terrain == self.opponent_index, 1, 0) - visible_neutral_cells = np.where(terrain == -1, 1, 0) - print(self.generals) + cities = np.zeros((height, width)) + for city in self.cities: + cities[city // width, city % width] = 1 + + generals = np.zeros((height, width)) + for general in self.generals: + if general != -1: + generals[general // width, general % width] = 1 + _observation = { + "army": armies, + "general": generals, + "city": cities, + "owned_cells": np.where(terrain == self.player_index, 1, 0), + "opponent_cells": np.where(terrain == self.opponent_index, 1, 0), + "neutral_cells": np.where(terrain == -1, 1, 0), + "visible_cells": maximum_filter(np.where(terrain == self.player_index, 1, 0), size=3), + "structures_in_fog": np.where(terrain == -4, 1, 0), + "owned_land_count": self.scores[self.player_index]["tiles"], + "owned_army_count": self.scores[self.player_index]["total"], + "opponent_land_count": self.scores[self.opponent_index]["tiles"], + "opponent_army_count": self.scores[self.opponent_index]["total"], + "is_winner": False, + "timestep": self.turn, + } + + observation = { + "observation": _observation, + } + return observation class GeneralsIOClient(SimpleClient):