diff --git a/examples/complete_example.py b/examples/complete_example.py
index 2ca9f83..2b89c65 100644
--- a/examples/complete_example.py
+++ b/examples/complete_example.py
@@ -17,8 +17,9 @@
 env = gym.make(
     "gym-generals-v0",  # Environment name
     grid_factory=grid_factory,  # Grid factory
-    agent=agent,  # Your agent (used to get metadata like name and color)
     npc=npc,  # NPC that will play against the agent
+    agent_id="Agent",  # Agent ID
+    agent_color=(67, 70, 86),  # Agent color
     render_mode="human",  # "human" mode is for rendering, None is for no rendering
 )
diff --git a/examples/gymnasium_example.py b/examples/gymnasium_example.py
index 8974830..5ecc604 100644
--- a/examples/gymnasium_example.py
+++ b/examples/gymnasium_example.py
@@ -2,20 +2,14 @@
 from generals import AgentFactory
 
 # Initialize agents
-agent = AgentFactory.make_agent("expander")
 npc = AgentFactory.make_agent("random")
 
-env = gym.make(
-    "gym-generals-v0",
-    agent=agent,
-    npc=npc,
-    render_mode="human",
-)
+# Create environment
+env = gym.make("gym-generals-v0", npc=npc, render_mode="human")
 
 observation, info = env.reset()
-
 terminated = truncated = False
 while not (terminated or truncated):
-    action = agent.act(observation)
+    action = env.action_space.sample()  # Here you put your agent's action
     observation, reward, terminated, truncated, info = env.step(action)
     env.render()
diff --git a/examples/pettingzoo_example.py b/examples/pettingzoo_example.py
index d32358d..2badac9 100644
--- a/examples/pettingzoo_example.py
+++ b/examples/pettingzoo_example.py
@@ -6,13 +6,12 @@
 expander = AgentFactory.make_agent("expander")
 
 agents = {
-    random.name: random,
-    expander.name: expander,
+    random.id: random,
+    expander.id: expander,
 }  # Environment calls agents by name
 
 # Create environment -- render modes: {None, "human"}
-# env = pz_generals(agents=agents, render_mode="human")D
-env = gym.make("pz-generals-v0", agents=agents, render_mode="human")
+env = gym.make("pz-generals-v0", agents=list(agents.keys()), render_mode="human")
 
 observations, info = env.reset()
 done = False
@@ -25,4 +24,4 @@
     # All agents perform their actions
     observations, rewards, terminated, truncated, info = env.step(actions)
     done = any(terminated.values()) or any(truncated.values())
-    env.render(fps=6)
+    env.render()
diff --git a/examples/record_replay_example.py b/examples/record_replay_example.py
index faee6e0..f62f772 100644
--- a/examples/record_replay_example.py
+++ b/examples/record_replay_example.py
@@ -17,7 +17,8 @@
 env = gym.make(
     "gym-generals-v0",  # Environment name
     grid_factory=grid_factory,  # Grid factory
-    agent=agent,  # Your agent (used to get metadata like name and color)
+    agent_id="Agent",  # Agent ID
+    agent_color=(67, 70, 86),  # Agent color
     npc=npc,  # NPC that will play against the agent
 )
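For context on the recording flow these examples rely on: replays are driven through `reset` options rather than constructor arguments (see the `replay_file` handling in the environment changes further down). A minimal sketch, assuming a constructed `env`; the file name is hypothetical:

```python
# Record this episode to "my_replay"; the replay is stored automatically
# once the game terminates (see GymnasiumGenerals.step below).
observation, info = env.reset(options={"replay_file": "my_replay"})
```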
""" - def __init__(self, name="Agent", color=(67, 70, 86)): - self.name = name + def __init__(self, id="NPC", color=(67, 70, 86)): + self.id = id self.color = color @abstractmethod @@ -27,12 +27,4 @@ def reset(self): raise NotImplementedError def __str__(self): - return self.name - - -class EmptyAgent(Agent): - def act(self, observation): - return None - - def reset(self): - pass + return self.id diff --git a/generals/agents/agent_factory.py b/generals/agents/agent_factory.py index 84ec090..5b9a16b 100644 --- a/generals/agents/agent_factory.py +++ b/generals/agents/agent_factory.py @@ -1,6 +1,7 @@ from .random_agent import RandomAgent from .expander_agent import ExpanderAgent -from .agent import Agent, EmptyAgent +from .agent import Agent + class AgentFactory: """ @@ -19,7 +20,5 @@ def make_agent(agent_type: str, **kwargs) -> Agent: return RandomAgent(**kwargs) elif agent_type == "expander": return ExpanderAgent(**kwargs) - elif agent_type == "empty": - return EmptyAgent(**kwargs) else: raise ValueError(f"Unknown agent type: {agent_type}") diff --git a/generals/agents/expander_agent.py b/generals/agents/expander_agent.py index fec8e7a..1d05e37 100644 --- a/generals/agents/expander_agent.py +++ b/generals/agents/expander_agent.py @@ -5,8 +5,8 @@ class ExpanderAgent(Agent): - def __init__(self, name="Expander", color=(0, 130, 255)): - super().__init__(name, color) + def __init__(self, id="Expander", color=(0, 130, 255)): + super().__init__(id, color) def act(self, observation): """ @@ -30,7 +30,9 @@ def act(self, observation): directions = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT] for i, action in enumerate(valid_actions): - di, dj = action[:-1] + directions[action[-1]].value # Destination cell indices + di, dj = ( + action[:-1] + directions[action[-1]].value + ) # Destination cell indices if army[action[0], action[1]] <= army[di, dj] + 1: # Can't capture continue elif opponent[di, dj]: diff --git a/generals/agents/random_agent.py b/generals/agents/random_agent.py index daf69b4..9b60047 100644 --- a/generals/agents/random_agent.py +++ b/generals/agents/random_agent.py @@ -4,9 +4,9 @@ class RandomAgent(Agent): def __init__( - self, name="Random", color=(242, 61, 106), split_prob=0.25, idle_prob=0.05 + self, id="Random", color=(242, 61, 106), split_prob=0.25, idle_prob=0.05 ): - super().__init__(name, color) + super().__init__(id, color) self.idle_probability = idle_prob self.split_probability = split_prob diff --git a/generals/core/game.py b/generals/core/game.py index 5152532..e2e2e91 100644 --- a/generals/core/game.py +++ b/generals/core/game.py @@ -15,7 +15,7 @@ Action: TypeAlias = gym.Space Info: TypeAlias = dict[str, Any] -increment_rate = 50 + DIRECTIONS = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT] @@ -33,12 +33,16 @@ def __init__(self, grid: Grid, agents: list[str]): self.channels = Channels(map, self.agents) + # Constants + self.increment_rate = 50 + self.max_army_value = 10_000 + self.max_land_value = np.prod(self.grid_dims) + self.max_timestep = 100_000 ########## # Spaces # ########## - max_value = 100_000 grid_multi_binary = gym.spaces.MultiBinary(self.grid_dims) - grid_discrete = np.ones(self.grid_dims, dtype=int) * max_value + grid_discrete = np.ones(self.grid_dims, dtype=int) * self.max_army_value self.observation_space = gym.spaces.Dict( { "observation": gym.spaces.Dict( @@ -51,12 +55,12 @@ def __init__(self, grid: Grid, agents: list[str]): "neutral_cells": grid_multi_binary, "visible_cells": grid_multi_binary, "structure": 
diff --git a/generals/core/game.py b/generals/core/game.py
index 5152532..e2e2e91 100644
--- a/generals/core/game.py
+++ b/generals/core/game.py
@@ -15,7 +15,7 @@
 Action: TypeAlias = gym.Space
 Info: TypeAlias = dict[str, Any]
 
-increment_rate = 50
+
 DIRECTIONS = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]
 
@@ -33,12 +33,16 @@ def __init__(self, grid: Grid, agents: list[str]):
         self.channels = Channels(map, self.agents)
 
+        # Constants
+        self.increment_rate = 50
+        self.max_army_value = 10_000
+        self.max_land_value = np.prod(self.grid_dims)
+        self.max_timestep = 100_000
 
         ##########
         # Spaces #
         ##########
-        max_value = 100_000
         grid_multi_binary = gym.spaces.MultiBinary(self.grid_dims)
-        grid_discrete = np.ones(self.grid_dims, dtype=int) * max_value
+        grid_discrete = np.ones(self.grid_dims, dtype=int) * self.max_army_value
         self.observation_space = gym.spaces.Dict(
             {
                 "observation": gym.spaces.Dict(
@@ -51,12 +55,12 @@ def __init__(self, grid: Grid, agents: list[str]):
                         "neutral_cells": grid_multi_binary,
                         "visible_cells": grid_multi_binary,
                         "structure": grid_multi_binary,
-                        "owned_land_count": gym.spaces.Discrete(max_value),
-                        "owned_army_count": gym.spaces.Discrete(max_value),
-                        "opponent_land_count": gym.spaces.Discrete(max_value),
-                        "opponent_army_count": gym.spaces.Discrete(max_value),
+                        "owned_land_count": gym.spaces.Discrete(self.max_army_value),
+                        "owned_army_count": gym.spaces.Discrete(self.max_army_value),
+                        "opponent_land_count": gym.spaces.Discrete(self.max_army_value),
+                        "opponent_army_count": gym.spaces.Discrete(self.max_army_value),
                         "is_winner": gym.spaces.Discrete(2),
-                        "timestep": gym.spaces.Discrete(max_value),
+                        "timestep": gym.spaces.Discrete(self.max_timestep),
                     }
                 ),
                 "action_mask": gym.spaces.MultiBinary(self.grid_dims + (4,)),
@@ -238,7 +242,7 @@ def get_all_observations(self) -> dict[str, Observation]:
         """
         Returns observations for all agents.
         """
-        return {agent: self._agent_observation(agent) for agent in self.agents}
+        return {agent: self.agent_observation(agent) for agent in self.agents}
 
     def _global_game_update(self) -> None:
         """
@@ -248,7 +252,7 @@ def _global_game_update(self) -> None:
         owners = self.agents
 
         # every `increment_rate` steps, increase army size in each cell
-        if self.time % increment_rate == 0:
+        if self.time % self.increment_rate == 0:
             for owner in owners:
                 self.channels.army += self.channels.ownership[owner]
 
@@ -284,7 +288,7 @@ def get_infos(self) -> dict[str, Info]:
         }
         return players_stats
 
-    def _agent_observation(self, agent: str) -> Observation:
+    def agent_observation(self, agent: str) -> Observation:
         """
         Returns an observation for a given agent.
         Args:
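Attaching the bounds to the `Game` instance lets downstream code scale observations without magic numbers. A minimal sketch, assuming a running `game` and an agent id `"Agent"`:

```python
obs = game.agent_observation("Agent")["observation"]  # now a public method
army_scaled = obs["army"] / game.max_army_value       # roughly in [0, 1]
land_scaled = obs["owned_land_count"] / game.max_land_value
```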
diff --git a/generals/envs/env.py b/generals/envs/env.py
index 1fa1c40..7ce4522 100644
--- a/generals/envs/env.py
+++ b/generals/envs/env.py
@@ -1,5 +1,5 @@
-from .gymnasium_integration import Gym_Generals, RewardFn
-from .pettingzoo_integration import PZ_Generals
+from .gymnasium_generals import GymnasiumGenerals
+from .pettingzoo_generals import PettingZooGenerals
 
 from generals.agents import Agent, AgentFactory
 from generals import GridFactory
 
@@ -17,35 +17,36 @@ def pz_generals_v0(
     grid_factory: GridFactory = GridFactory(),
-    agents: dict[str, Agent] = None,
-    reward_fn: RewardFn = None,
+    agents: list[str] = None,
     render_mode=None,
 ):
-    assert len(agents) == 2, "Only 2 agents are supported in PZ_Generals."
-    env = PZ_Generals(
+    assert len(agents) == 2, "For now, only 2 agents are supported in PettingZooGenerals."
+    env = PettingZooGenerals(
         grid_factory=grid_factory,
         agents=agents,
-        reward_fn=reward_fn,
         render_mode=render_mode,
     )
     return env
 
-
 def gym_generals_v0(
     grid_factory: GridFactory = GridFactory(),
-    agent: Agent = None,
     npc: Agent = None,
-    reward_fn: RewardFn = None,
     render_mode=None,
+    reward_fn=None,
+    agent_id: str = "Agent",
+    agent_color: tuple[int, int, int] = (67, 70, 86),
 ):
-    if npc is None:
-        print("No npc provided, using RandomAgent.")
-        npc = AgentFactory.init_agent("random")
-    env = Gym_Generals(
+    if not isinstance(npc, Agent):
+        print(
+            "NPC must be an instance of the Agent class; creating a random NPC as a fallback."
+        )
+        npc = AgentFactory.make_agent("random")
+    env = GymnasiumGenerals(
         grid_factory=grid_factory,
-        agent=agent,
         npc=npc,
-        reward_fn=reward_fn,
         render_mode=render_mode,
+        agent_id=agent_id,
+        agent_color=agent_color,
+        reward_fn=reward_fn,
    )
     return env
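After this change the Gymnasium entry point takes plain metadata instead of an agent object. A minimal usage sketch; the id and color values are arbitrary:

```python
import gymnasium as gym
from generals.agents import AgentFactory

npc = AgentFactory.make_agent("expander")
env = gym.make(
    "gym-generals-v0",
    npc=npc,                    # must be an Agent instance
    agent_id="MyAgent",         # replaces the old agent= object
    agent_color=(0, 120, 255),
)
```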
diff --git a/generals/envs/gymnasium_integration.py b/generals/envs/gymnasium_generals.py
similarity index 61%
rename from generals/envs/gymnasium_integration.py
rename to generals/envs/gymnasium_generals.py
index a0d479a..2602a84 100644
--- a/generals/envs/gymnasium_integration.py
+++ b/generals/envs/gymnasium_generals.py
@@ -5,18 +5,20 @@
 import functools
 from copy import deepcopy
 
-from generals.agents import Agent
 from generals.core.game import Game, Action, Observation, Info
 from generals.core.grid import GridFactory
-from generals.gui import GUI
 from generals.core.replay import Replay
+from generals.agents import Agent
+from generals.gui import GUI
+from generals.gui.properties import GuiMode
 
 # Type aliases
 Reward: TypeAlias = float
 RewardFn: TypeAlias = Callable[[dict[str, Observation], Action, bool, Info], Reward]
+AgentID: TypeAlias = str
 
 
-class Gym_Generals(gym.Env):
+class GymnasiumGenerals(gym.Env):
     metadata = {
         "render_modes": ["human"],
         "render_fps": 6,
@@ -25,43 +27,50 @@ class Gym_Generals(gym.Env):
     def __init__(
         self,
         grid_factory: GridFactory = None,
-        agent: Agent = None,
         npc: Agent = None,
-        reward_fn: RewardFn = None,
+        reward_fn=None,
         render_mode=None,
+        agent_id: str = "Agent",
+        agent_color: tuple[int, int, int] = (67, 70, 86),
     ):
         self.render_mode = render_mode
-        self.reward_fn = self._default_reward if reward_fn is None else reward_fn
         self.grid_factory = grid_factory
+        if reward_fn is not None:
+            self.reward_fn = reward_fn
+        else:
+            self.reward_fn = GymnasiumGenerals._default_reward
 
-        assert isinstance(agent, Agent), "Agent must be an instance of Agent class."
+        # Agents
         assert isinstance(npc, Agent), "NPC must be an instance of Agent class."
-        self.agent = agent
         self.npc = npc
-
-        self.agent_data = {agent.name: {"color": agent.color} for agent in [agent, npc]}
-
-        # Check whether agents have unique names
+        self.agent_id = agent_id
+        self.agent_ids = [self.agent_id, self.npc.id]
+        self.agent_data = {
+            agent_id: {"color": agent_color},
+            self.npc.id: {"color": self.npc.color},
+        }
         assert (
-            agent.name != npc.name
-        ), "Agent names must be unique - you can pass custom names to agent constructors."
+            agent_id != npc.id
+        ), "Agent ids must be unique - you can pass custom ids to agent constructors."
 
+        # Game
         grid = self.grid_factory.grid_from_generator()
-        game = Game(grid, [self.agent.name, self.npc.name])
-        self.observation_space = game.observation_space
-        self.action_space = game.action_space
+        self.game = Game(grid, [self.agent_id, self.npc.id])
+        self.observation_space = self.game.observation_space
+        self.action_space = self.game.action_space
 
     @functools.lru_cache(maxsize=None)
     def observation_space(self) -> gym.Space:
         return self.game.observation_space
 
     @functools.lru_cache(maxsize=None)
     def action_space(self) -> gym.Space:
         return self.game.action_space
 
-    def render(self):
+    def render(self, fps: int = None) -> None:
+        fps = self.metadata["render_fps"] if fps is None else fps
         if self.render_mode == "human":
-            _ = self.gui.tick(fps=self.metadata["render_fps"])
+            _ = self.gui.tick(fps=fps)
 
     def reset(
         self, seed: int | None = None, options: dict[str, Any] | None = None
@@ -69,20 +78,15 @@ def reset(
         if options is None:
             options = {}
         super().reset(seed=seed)
-        # If map is not provided, generate a new one
         if "grid" in options:
             grid = self.grid_factory.grid_from_string(options["grid"])
         else:
             grid = self.grid_factory.grid_from_generator(seed=seed)
-        self.game = Game(grid, [self.agent.name, self.npc.name])
-        self.npc.reset()
-
-        self.observation_space = self.game.observation_space
-        self.action_space = self.game.action_space
+        self.game = Game(grid, self.agent_ids)
+        self.npc.reset()
 
         if self.render_mode == "human":
-            self.gui = GUI(self.game, self.agent_data)
+            self.gui = GUI(self.game, self.agent_data, GuiMode.TRAIN)
 
         if "replay_file" in options:
             self.replay = Replay(
@@ -94,24 +98,26 @@ def reset(
         elif hasattr(self, "replay"):
             del self.replay
 
-        observation = self.game._agent_observation(self.agent.name)
+        self.observation_space = self.game.observation_space
+        self.action_space = self.game.action_space
+
+        observation = self.game.agent_observation(self.agent_id)
         info = {}
         return observation, info
 
     def step(
         self, action: Action
     ) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
-        # get action of NPC
-        npc_action = self.npc.act(self.game._agent_observation(self.npc.name))
-        actions = {self.agent.name: action, self.npc.name: npc_action}
+        # Get action of NPC
+        npc_action = self.npc.act(self.game.agent_observation(self.npc.id))
+        actions = {self.agent_id: action, self.npc.id: npc_action}
 
         observations, infos = self.game.step(actions)
-        observation = observations[self.agent.name]
-        info = infos[self.agent.name]
-        truncated = False
-        terminated = True if self.game.is_done() else False
-        done = terminated or truncated
-        reward = self.reward_fn(observation, action, done, info)
+        obs = observations[self.agent_id]
+        info = infos[self.agent_id]
+        reward = self.reward_fn(obs, action, self.game.is_done(), info)
+        terminated = self.game.is_done()
+        truncated = self.game.time >= 120  # choose your own truncation constant
 
         if hasattr(self, "replay"):
             self.replay.add_state(deepcopy(self.game.channels))
@@ -119,8 +125,7 @@ def step(
         if terminated:
             if hasattr(self, "replay"):
                 self.replay.store()
-
-        return observation, reward, terminated, truncated, info
+        return obs, reward, terminated, truncated, info
 
     @staticmethod
     def _default_reward(
@@ -130,7 +135,6 @@ def _default_reward(
         observation: Observation,
         action: Action,
         done: bool,
         info: Info,
     ) -> Reward:
         """
-        Calculate rewards for each agent.
         Give 0 if game still running, otherwise 1 for winner and -1 for loser.
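Since `reward_fn` is now injected, any callable matching the `RewardFn` signature can replace the default. A minimal sketch, assuming `npc` from the earlier examples; the land-delta shaping is illustrative, not part of this change:

```python
def land_reward(observation, action, done, info):
    obs = observation["observation"]
    if done:
        return 1 if obs["is_winner"] else -1
    # Hypothetical shaping: favor out-expanding the opponent.
    return 0.01 * (obs["owned_land_count"] - obs["opponent_land_count"])

env = gym.make("gym-generals-v0", npc=npc, reward_fn=land_reward)
```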
""" if done: @@ -140,5 +144,5 @@ def _default_reward( return reward def close(self) -> None: - if hasattr(self, "gui"): + if hasattr(self, "replay"): self.gui.close() diff --git a/generals/envs/gymnasium_wrappers.py b/generals/envs/gymnasium_wrappers.py new file mode 100644 index 0000000..31305ae --- /dev/null +++ b/generals/envs/gymnasium_wrappers.py @@ -0,0 +1,145 @@ +import gymnasium as gym +import numpy as np + + +class NormalizeObservationWrapper(gym.ObservationWrapper): + def __init__(self, env): + super(NormalizeObservationWrapper, self).__init__(env) + grid_multi_binary = gym.spaces.MultiBinary(self.game.grid_dims) + unit_box = gym.spaces.Box(low=0, high=1, dtype=np.float32) + self.observation_space = gym.spaces.Dict( + { + "observation": gym.spaces.Dict( + { + "army": gym.spaces.Box( + low=0, high=1, shape=self.game.grid_dims, dtype=np.float32 + ), + "general": grid_multi_binary, + "city": grid_multi_binary, + "owned_cells": grid_multi_binary, + "opponent_cells": grid_multi_binary, + "neutral_cells": grid_multi_binary, + "visible_cells": grid_multi_binary, + "structure": grid_multi_binary, + "owned_land_count": unit_box, + "owned_army_count": unit_box, + "opponent_land_count": unit_box, + "opponent_army_count": unit_box, + "is_winner": gym.spaces.Discrete(2), + "timestep": unit_box, + } + ), + "action_mask": gym.spaces.MultiBinary(self.game.grid_dims + (4,)), + } + ) + + def observation(self, observation): + game = self.game + _observation = ( + observation["observation"] if "observation" in observation else observation + ) + _observation["army"] = np.array( + _observation["army"] / game.max_army_value, dtype=np.float32 + ) + _observation["timestep"] = np.array( + [_observation["timestep"] / game.max_timestep], dtype=np.float32 + ) + _observation["owned_land_count"] = np.array( + [_observation["owned_land_count"] / game.max_land_value], dtype=np.float32 + ) + _observation["opponent_land_count"] = np.array( + [_observation["opponent_land_count"] / game.max_land_value], + dtype=np.float32, + ) + _observation["owned_army_count"] = np.array( + [_observation["owned_army_count"] / game.max_army_value], dtype=np.float32 + ) + _observation["opponent_army_count"] = np.array( + [_observation["opponent_army_count"] / game.max_army_value], + dtype=np.float32, + ) + observation["observation"] = _observation + return observation + + +class RemoveActionMaskWrapper(gym.ObservationWrapper): + def __init__(self, env): + super(RemoveActionMaskWrapper, self).__init__(env) + grid_multi_binary = gym.spaces.MultiBinary(self.game.grid_dims) + unit_box = gym.spaces.Box(low=0, high=1, dtype=np.float32) + self.observation_space = gym.spaces.Dict( + { + "army": gym.spaces.Box( + low=0, high=1, shape=self.game.grid_dims, dtype=np.float32 + ), + "general": grid_multi_binary, + "city": grid_multi_binary, + "owned_cells": grid_multi_binary, + "opponent_cells": grid_multi_binary, + "neutral_cells": grid_multi_binary, + "visible_cells": grid_multi_binary, + "structure": grid_multi_binary, + "owned_land_count": unit_box, + "owned_army_count": unit_box, + "opponent_land_count": unit_box, + "opponent_army_count": unit_box, + "is_winner": gym.spaces.Discrete(2), + "timestep": unit_box, + } + ) + + def observation(self, observation): + _observation = ( + observation["observation"] if "observation" in observation else observation + ) + return _observation + + +class ObservationAsImageWrapper(gym.ObservationWrapper): + def __init__(self, env): + super(ObservationAsImageWrapper, self).__init__(env) + 
diff --git a/generals/envs/pettingzoo_integration.py b/generals/envs/pettingzoo_generals.py
similarity index 82%
rename from generals/envs/pettingzoo_integration.py
rename to generals/envs/pettingzoo_generals.py
index c791127..1cf610d 100644
--- a/generals/envs/pettingzoo_integration.py
+++ b/generals/envs/pettingzoo_generals.py
@@ -6,52 +6,56 @@
 from gymnasium import spaces
 from copy import deepcopy
 
-from pettingzoo.utils.env import AgentID
-
 from generals.core.game import Game, Action, Observation, Info
 from generals.core.grid import GridFactory
-from generals.agents import Agent
+from generals.core.replay import Replay
 from generals.gui import GUI
 from generals.gui.properties import GuiMode
-from generals.core.replay import Replay
-
 
 # Type aliases
 Reward: TypeAlias = float
 RewardFn: TypeAlias = Callable[[dict[str, Observation], Action, bool, Info], Reward]
+AgentID: TypeAlias = str
 
 
-class PZ_Generals(pettingzoo.ParallelEnv):
+class PettingZooGenerals(pettingzoo.ParallelEnv):
     metadata = {
         "render_modes": ["human"],
         "render_fps": 6,
     }
+    default_colors = [
+        (67, 70, 86),
+        (242, 61, 106),
+        (0, 255, 0),
+        (0, 0, 255),
+    ]  # Up for improvement (needs to be extended for multiple agents)
 
     def __init__(
         self,
         grid_factory: GridFactory,
-        agents: dict[str, Agent],
+        agents: list[str],
         reward_fn: RewardFn = None,
         render_mode=None,
     ):
-        self.game = None
-        self.gui = None
-        self.replay = None
-
         self.render_mode = render_mode
         self.grid_factory = grid_factory
+        if reward_fn is not None:
+            self.reward_fn = reward_fn
+        else:
+            self.reward_fn = PettingZooGenerals._default_reward
 
         self.agent_data = {
-            agents[agent].name: {"color": agents[agent].color}
-            for agent in agents.keys()
+            agent_id: {"color": color}
+            for agent_id, color in zip(agents, self.default_colors)
         }
-        self.possible_agents = list(agents.keys())
+        self.agents = agents
+        self.possible_agents = agents
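With agents passed as plain ids, callers keep their own id → agent mapping and query it inside the loop. A minimal sketch, assuming the registered `pz-generals-v0` id used in the examples:

```python
import gymnasium as gym
from generals.agents import AgentFactory

agents = {a.id: a for a in (AgentFactory.make_agent("random"),
                            AgentFactory.make_agent("expander"))}
env = gym.make("pz-generals-v0", agents=list(agents.keys()))
observations, info = env.reset()
actions = {aid: agents[aid].act(observations[aid]) for aid in env.agents}
observations, rewards, terminated, truncated, info = env.step(actions)
```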
-        assert (
-            len(self.possible_agents) == len(set(self.possible_agents))
-        ), "Agent names must be unique - you can pass custom names to agent constructors."
+        assert len(self.possible_agents) == len(
+            set(self.possible_agents)
+        ), "Agent ids must be unique - you can pass custom ids to agent constructors."
 
-        self.reward_fn = self._default_reward if reward_fn is None else reward_fn
 
     @functools.lru_cache(maxsize=None)
     def observation_space(self, agent: AgentID) -> spaces.Space:
@@ -74,7 +78,6 @@ def reset(
         if options is None:
             options = {}
         self.agents = deepcopy(self.possible_agents)
-
         if "grid" in options:
             grid = self.grid_factory.grid_from_string(options["grid"])
         else:
@@ -109,12 +112,10 @@ def step(
         dict[AgentID, Info],
     ]:
         observations, infos = self.game.step(actions)
-
         truncated = {agent: False for agent in self.agents}  # no truncation
         terminated = {
             agent: True if self.game.is_done() else False for agent in self.agents
         }
-
         rewards = {
             agent: self.reward_fn(
                 observations[agent],
@@ -134,7 +135,6 @@ def step(
             self.agents = []
             if hasattr(self, "replay"):
                 self.replay.store()
-
         return observations, rewards, terminated, truncated, infos
 
     @staticmethod
@@ -154,4 +154,5 @@ def _default_reward(
         return reward
 
     def close(self) -> None:
-        self.gui.close()
+        if self.render_mode == "human":
+            self.gui.close()
diff --git a/generals/gui/rendering.py b/generals/gui/rendering.py
index 9c46f6d..def3d98 100644
--- a/generals/gui/rendering.py
+++ b/generals/gui/rendering.py
@@ -189,8 +189,8 @@ def render_grid(self):
         """
         agents = self.game.agents
         # Maps of all owned and visible cells
-        owned_map = np.zeros((self.grid_height, self.grid_width), dtype=np.bool)
-        visible_map = np.zeros((self.grid_height, self.grid_width), dtype=np.bool)
+        owned_map = np.zeros((self.grid_height, self.grid_width), dtype=bool)
+        visible_map = np.zeros((self.grid_height, self.grid_width), dtype=bool)
         for agent in agents:
             ownership = self.game.channels.ownership[agent]
             owned_map = np.logical_or(owned_map, ownership)
diff --git a/tests/gym_check.py b/tests/gym_check.py
index ae5be8d..d3d9e6f 100644
--- a/tests/gym_check.py
+++ b/tests/gym_check.py
@@ -2,13 +2,12 @@
 import gymnasium.utils.env_checker as env_checker
 from generals.agents import AgentFactory
 
-agent = AgentFactory.make_agent("expander", name="A")
-npc = AgentFactory.make_agent("random", name="B")
+npc = AgentFactory.make_agent("random")
 env = gym.make(
     "gym-generals-v0",
-    agent=agent,
+    agent_id="tester",
     npc=npc,
 )
 env_checker.check_env(env.unwrapped)
-print('Gymnasium check passed!')
+print("Gymnasium check passed!")
diff --git a/tests/parallel_api_check.py b/tests/parallel_api_check.py
index a5b622e..f26988e 100644
--- a/tests/parallel_api_check.py
+++ b/tests/parallel_api_check.py
@@ -1,10 +1,9 @@
 from __future__ import annotations
 
-from generals import pz_generals
-from generals.agents import RandomAgent
-from generals.core.grid import GridFactory
+from generals import GridFactory, AgentFactory
 import warnings
 import numpy as np
+import gymnasium as gym
 
 from pettingzoo.test.api_test import missing_attr_warning
 from pettingzoo.utils.conversions import (
@@ -141,13 +140,13 @@ def parallel_api_test(par_env: ParallelEnv, num_cycles=1000):
 if __name__ == "__main__":
     mapper = GridFactory()
-    agent1 = RandomAgent(name="A")
-    agent2 = RandomAgent(name="B")
+    agent1 = AgentFactory.make_agent("expander", id="A")
+    agent2 = AgentFactory.make_agent("random", id="B")
     agents = {
-        agent1.name: agent1,
-        agent2.name: agent2,
+        agent1.id: agent1,
+        agent2.id: agent2,
     }
-    env = pz_generals(mapper, agents)
+    env = gym.make("pz-generals-v0", agents=list(agents.keys()), grid_factory=mapper)
     # test the environment with parallel_api_test
     import time
 
     start = time.time()
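One caveat about the rewritten `__main__` block above: `gym.make` returns a Gymnasium wrapper, while `parallel_api_test` expects a bare `ParallelEnv`, so the test likely needs the unwrapped instance (an assumption, not verified here):

```python
parallel_api_test(env.unwrapped, num_cycles=1000)
```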
diff --git a/tests/test_game.py b/tests/test_game.py
index 0c232cd..b05d2eb 100644
--- a/tests/test_game.py
+++ b/tests/test_game.py
@@ -227,7 +227,7 @@ def test_observations():
     ############
     # TEST RED #
     ############
-    red_observation = game._agent_observation("red")["observation"]
+    red_observation = game.agent_observation("red")["observation"]
     reference_opponent_ownership = np.array(
         [
             [0, 0, 0, 0],
@@ -286,7 +286,7 @@
     #############
     # TEST BLUE #
     #############
-    blue_observation = game._agent_observation("blue")["observation"]
+    blue_observation = game.agent_observation("blue")["observation"]
     reference_opponent_ownership = np.array(
         [
             [0, 0, 0, 0],