From 4c05741eb94560e8a47f93a9168ba9945b0e8a1a Mon Sep 17 00:00:00 2001 From: Matej Straka Date: Sun, 12 Jan 2025 17:32:47 +0100 Subject: [PATCH] chore: Reset to stable state --- .github/workflows/code-embedder.yml | 9 +- .pre-commit-config.yaml | 8 - README.md | 58 +++--- examples/gymnasium_example.py | 35 ++-- examples/vectorized_multiagent_gymnasium.py | 29 --- generals/__init__.py | 25 +-- generals/envs/__init__.py | 2 - generals/envs/gymnasium_generals.py | 133 ++++++------- generals/envs/gymnasium_wrappers.py | 55 ----- generals/envs/initializers.py | 61 ------ .../envs/multiagent_gymnasium_generals.py | 188 ------------------ pyproject.toml | 4 - tests/test_gym.py | 15 -- 13 files changed, 133 insertions(+), 489 deletions(-) delete mode 100644 examples/vectorized_multiagent_gymnasium.py delete mode 100644 generals/envs/gymnasium_wrappers.py delete mode 100644 generals/envs/initializers.py delete mode 100644 generals/envs/multiagent_gymnasium_generals.py delete mode 100644 tests/test_gym.py diff --git a/.github/workflows/code-embedder.yml b/.github/workflows/code-embedder.yml index 0d7f1a7..2e277e1 100644 --- a/.github/workflows/code-embedder.yml +++ b/.github/workflows/code-embedder.yml @@ -2,15 +2,20 @@ name: Code Embedder on: pull_request +permissions: + contents: write + jobs: code_embedder: name: "Code embedder" runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} - name: Run code embedder - uses: kvankova/code-embedder@v0.4.0 + uses: kvankova/code-embedder@v1.1.1 env: GITHUB_TOKEN: ${{ secrets.CODE_EMBEDDER }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc7fd8c..df3fe74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,14 +30,6 @@ repos: types: - python - - id: mypy - name: mypy - entry: poetry run mypy . - pass_filenames: false - language: system - types: - - python - - id: run-pytest name: Run tests entry: make test diff --git a/README.md b/README.md index a9595c9..4e67039 100644 --- a/README.md +++ b/README.md @@ -46,29 +46,37 @@ pip install -e . ## 🌱 Getting Started Creating an agent is very simple. Start by subclassing an `Agent` class just like [`RandomAgent`](./generals/agents/random_agent.py) or [`ExpanderAgent`](./generals/agents/expander_agent.py). -You can specify your agent `id` (name) and `color` and the only thing remaining is to implement the `act` function, +You can specify your agent `id` (name) and the only thing remaining is to implement the `act` function, that has the signature explained in sections down below. 
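For illustration, a minimal agent can be as small as the sketch below. This is only a sketch against the API used elsewhere in this README: it assumes the `Agent` base class from `generals.agents` accepts an `id` in its constructor, and that `act` receives an observation and returns an action shaped like the environments' `MultiDiscrete([2, H, W, 4, 2])` action space (pass flag, row, column, direction, split); the exact encoding of a "pass" move is an assumption here.

```python
import numpy as np

from generals.agents import Agent


class IdleAgent(Agent):
    """Toy agent that never issues a move; only `act` has to be implemented."""

    def __init__(self, id: str = "Idle"):
        # Assumes the base constructor accepts the agent id.
        super().__init__(id)

    def act(self, observation) -> np.ndarray:
        # Action layout assumed from the MultiDiscrete([2, H, W, 4, 2]) action space:
        # [pass, row, col, direction, split]; a leading 1 is taken to mean "do nothing".
        return np.array([1, 0, 0, 0, 0])
```

Such an agent plugs straight into the loops shown in the next section.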
-### Usage Example (🤸 Gymnasium) +### Usage Example (🦁 PettingZoo) The example loop for running the game looks like this -```python:examples/gymnasium_example.py -import gymnasium as gym - +```python:examples/pettingzoo_example.py from generals.agents import RandomAgent, ExpanderAgent +from generals.envs import PettingZooGenerals # Initialize agents -agent = RandomAgent() -npc = ExpanderAgent() +random = RandomAgent(id="Random") +expander = ExpanderAgent(id="Expander") + +# Store agents in a dictionary +agents = { + random.id: random, + expander.id: expander +} # Create environment -env = gym.make("gym-generals-v0", agent=agent, npc=npc, render_mode="human") +env = PettingZooGenerals(agent_ids=[random.id, expander.id], to_render=True) +observations, info = env.reset() -observation, info = env.reset() terminated = truncated = False while not (terminated or truncated): - action = agent.act(observation) - observation, reward, terminated, truncated, info = env.step(action) + actions = {} + for agent in env.agent_ids: + actions[agent] = agents[agent].act(observations[agent]) + # All agents perform their actions + observations, rewards, terminated, truncated, info = env.step(actions) env.render() ``` @@ -80,7 +88,7 @@ while not (terminated or truncated): Grids on which the game is played on are generated via `GridFactory`. You can instantiate the class with desired grid properties, and it will generate grid with these properties for each run. ```python -import gymnasium as gym +from generals.envs import PettingZooGenerals from generals import GridFactory grid_factory = GridFactory( @@ -89,21 +97,20 @@ grid_factory = GridFactory( mountain_density=0.2, # Probability of a mountain in a cell city_density=0.05, # Probability of a city in a cell general_positions=[(0,3),(5,7)], # Positions of generals (i, j) - padding=True # Whether to pad grids to max_grid_dims with mountains (default=True) ) # Create environment -env = gym.make( - "gym-generals-v0", +env = PettingZooGenerals( grid_factory=grid_factory, ... ) ``` You can also specify grids manually, as a string via `options` dict: ```python -import gymnasium as gym +from generals.envs import PettingZooGenerals + +env = PettingZooGenerals(agent_ids=[agent1.id, agent2.id]) -env = gym.make("gym-generals-v0", ...) grid = """ .3.# #..A @@ -123,13 +130,14 @@ Grids are created using a string format where: - numbers `0-9` and `x`, where `x=10`, represent cities, where the number specifies amount of neutral army in the city, which is calculated as `40 + number`. The reason for `x=10` is that the official game has cities in range `[40, 50]` +> [!TIP] +> Check out [complete example](./examples/complete_example.py) for concrete example in the wild! + ## 🔬 Interactive Replays We can store replays and then analyze them in an interactive fashion. `Replay` class handles replay related functionality. ### Storing a replay ```python -import gymnasium as gym - -env = gym.make("gym-generals-v0", ...) +env = ... options = {"replay_file": "my_replay"} env.reset(options=options) # The next game will be encoded in my_replay.pkl @@ -220,15 +228,19 @@ or create your own connection workflow. Our implementations expect that your age implemented the required methods. ```python:examples/client_example.py from generals.remote import autopilot +from generals.agents import ExpanderAgent + import argparse + parser = argparse.ArgumentParser() parser.add_argument("--user_id", type=str, default=...) 
# Register yourself at generals.io and use this id -parser.add_argument("--lobby_id", type=str, default=...) # The last part of the lobby url -parser.add_argument("--agent_id", type=str, default="Expander") # agent_id should be "registered" in AgentFactory +parser.add_argument("--lobby_id", type=str, default="psyo") # After you create a private lobby, copy last part of the url if __name__ == "__main__": args = parser.parse_args() - autopilot(args.agent_id, args.user_id, args.lobby_id) + agent = ExpanderAgent() + autopilot(agent, args.user_id, args.lobby_id) + ``` This script will run `ExpanderAgent` in the specified lobby. ## 🙌 Contributing diff --git a/examples/gymnasium_example.py b/examples/gymnasium_example.py index d5b81d7..c9d204a 100644 --- a/examples/gymnasium_example.py +++ b/examples/gymnasium_example.py @@ -1,17 +1,28 @@ +# type: ignore import gymnasium as gym +import numpy as np -from generals.agents import RandomAgent, ExpanderAgent +from generals.envs import GymnasiumGenerals -# Initialize agents -agent = RandomAgent() -npc = ExpanderAgent() +agent_names = ["007", "Generalissimo"] -# Create environment -env = gym.make("gym-generals-v0", agent=agent, npc=npc, render_mode="human") +n_envs = 12 +envs = gym.vector.AsyncVectorEnv( + [lambda: GymnasiumGenerals(agents=agent_names, truncation=500) for _ in range(n_envs)], +) -observation, info = env.reset() -terminated = truncated = False -while not (terminated or truncated): - action = agent.act(observation) - observation, reward, terminated, truncated, info = env.step(action) - env.render() + +observations, infos = envs.reset() +terminated = [False] * len(observations) +truncated = [False] * len(observations) + +while True: + agent_actions = [envs.single_action_space.sample() for _ in range(n_envs)] + npc_actions = [envs.single_action_space.sample() for _ in range(n_envs)] + + # Stack actions together + actions = np.stack([agent_actions, npc_actions], axis=1) + observations, rewards, terminated, truncated, infos = envs.step(actions) + masks = [np.stack([info[-1] for info in infos[agent_name]]) for agent_name in agent_names] + if any(terminated) or any(truncated): + break diff --git a/examples/vectorized_multiagent_gymnasium.py b/examples/vectorized_multiagent_gymnasium.py deleted file mode 100644 index e79b519..0000000 --- a/examples/vectorized_multiagent_gymnasium.py +++ /dev/null @@ -1,29 +0,0 @@ -# type: ignore -import gymnasium as gym -import numpy as np - -from generals.envs import MultiAgentGymnasiumGenerals - -agent_names = ["007", "Generalissimo"] - -n_envs = 12 -envs = gym.vector.AsyncVectorEnv( - [lambda: MultiAgentGymnasiumGenerals(agents=agent_names, truncation=500) for _ in range(n_envs)], -) - - -observations, infos = envs.reset() -terminated = [False] * len(observations) -truncated = [False] * len(observations) - -while True: - agent_actions = [envs.single_action_space.sample() for _ in range(n_envs)] - npc_actions = [envs.single_action_space.sample() for _ in range(n_envs)] - - # Stack actions together - actions = np.stack([agent_actions, npc_actions], axis=1) - observations, rewards, terminated, truncated, infos = envs.step(actions) - masks = [np.stack([info[-1] for info in infos[agent_name]]) for agent_name in agent_names] - if any(terminated) or any(truncated): - print("DONE!") - break diff --git a/generals/__init__.py b/generals/__init__.py index 1a53bea..c1f123b 100644 --- a/generals/__init__.py +++ b/generals/__init__.py @@ -1,10 +1,9 @@ -from gymnasium.envs.registration import register - from 
generals.agents.agent import Agent from generals.core.game import Action from generals.core.grid import Grid, GridFactory from generals.core.observation import Observation from generals.core.replay import Replay +from generals.envs.gymnasium_generals import GymnasiumGenerals from generals.envs.pettingzoo_generals import PettingZooGenerals __all__ = [ @@ -12,28 +11,8 @@ "Agent", "GridFactory", "PettingZooGenerals", + "GymnasiumGenerals", "Grid", "Replay", "Observation", - "GeneralsIOClientError", ] - - -def _register_gym_generals_envs(): - register( - id="gym-generals-v0", - entry_point="generals.envs.gymnasium_generals:GymnasiumGenerals", - ) - - register( - id="gym-generals-image-v0", - entry_point="generals.envs.initializers:gym_image_observations", - ) - - register( - id="gym-generals-rllib-v0", - entry_point="generals.envs.initializers:gym_rllib", - ) - - -_register_gym_generals_envs() diff --git a/generals/envs/__init__.py b/generals/envs/__init__.py index e5008a9..896361a 100644 --- a/generals/envs/__init__.py +++ b/generals/envs/__init__.py @@ -1,10 +1,8 @@ # type: ignore from generals.envs.gymnasium_generals import GymnasiumGenerals -from generals.envs.multiagent_gymnasium_generals import MultiAgentGymnasiumGenerals from generals.envs.pettingzoo_generals import PettingZooGenerals __all__ = [ "PettingZooGenerals", "GymnasiumGenerals", - "MultiAgentGymnasiumGenerals", ] diff --git a/generals/envs/gymnasium_generals.py b/generals/envs/gymnasium_generals.py index 172f98f..7cf0d9b 100644 --- a/generals/envs/gymnasium_generals.py +++ b/generals/envs/gymnasium_generals.py @@ -1,12 +1,13 @@ +# type: ignore from copy import deepcopy -from typing import Any, SupportsFloat +from typing import Any import gymnasium as gym import numpy as np from gymnasium import spaces -from generals.agents import Agent, RandomAgent -from generals.core.game import Action, Game +from generals.core.action import Action, compute_valid_move_mask +from generals.core.game import Game from generals.core.grid import Grid, GridFactory from generals.core.observation import Observation from generals.core.replay import Replay @@ -24,9 +25,8 @@ class GymnasiumGenerals(gym.Env): def __init__( self, + agents: list[str], grid_factory: GridFactory | None = None, - npc: Agent | None = None, - agent: Agent | None = None, # Optional, just to obtain id and color truncation: int | None = None, reward_fn: RewardFn | None = None, render_mode: str | None = None, @@ -35,24 +35,15 @@ def __init__( self.grid_factory = grid_factory if grid_factory is not None else GridFactory() self.reward_fn = reward_fn if reward_fn is not None else WinLoseRewardFn() # Observation for the agent at the prior time-step. - self.prior_observation: None | Observation = None + self.prior_observations: None | dict[str, Observation] = None - # Agents - if npc is None: - print('No NPC agent provided. Creating "Random" NPC as a fallback.') - npc = RandomAgent() - else: - assert isinstance(npc, Agent), "NPC must be an instance of Agent class." - self.npc = npc - self.agent_id = "Agent" if agent is None else agent.id - self.agent_ids = [self.agent_id, self.npc.id] + self.agents = agents self.colors = [(255, 107, 108), (0, 130, 255)] - self.agent_data = {id: {"color": color} for id, color in zip(self.agent_ids, self.colors)} - assert self.agent_id != npc.id, "Agent ids must be unique - you can pass custom ids to agent constructors." 
+ self.agent_data = {id: {"color": color} for id, color in zip(agents, self.colors)} # Game grid = self.grid_factory.generate() - self.game = Game(grid, [self.agent_id, self.npc.id]) + self.game = Game(grid, agents) self.truncation = truncation self.observation_space = self.set_observation_space() self.action_space = self.set_action_space() @@ -70,30 +61,8 @@ def set_observation_space(self) -> spaces.Space: dims = self.grid_factory.max_grid_dims else: dims = self.game.grid_dims - max_army_value = 100_000 - max_timestep = 100_000 - max_land_value = np.prod(dims) - grid_multi_binary = spaces.MultiBinary(dims) - grid_discrete = np.ones(dims, dtype=int) * 100_000 - return spaces.Dict( - { - "armies": spaces.MultiDiscrete(grid_discrete), - "generals": grid_multi_binary, - "cities": grid_multi_binary, - "mountains": grid_multi_binary, - "neutral_cells": grid_multi_binary, - "owned_cells": grid_multi_binary, - "opponent_cells": grid_multi_binary, - "fog_cells": grid_multi_binary, - "structures_in_fog": grid_multi_binary, - "owned_land_count": spaces.Discrete(max_land_value), - "owned_army_count": spaces.Discrete(max_army_value), - "opponent_land_count": spaces.Discrete(max_land_value), - "opponent_army_count": spaces.Discrete(max_army_value), - "timestep": spaces.Discrete(max_timestep), - "priority": spaces.Discrete(2), - } - ) + + return spaces.Box(low=0, high=2**31 - 1, shape=(2, 15, dims[0], dims[1]), dtype=np.float32) def set_action_space(self) -> spaces.Space: if self.grid_factory.padding: @@ -122,7 +91,7 @@ def reset( grid = self.grid_factory.generate() # Create game for current run - self.game = Game(grid, self.agent_ids) + self.game = Game(grid, self.agents) # Create GUI for current render run if self.render_mode == "human": @@ -138,33 +107,63 @@ def reset( elif hasattr(self, "replay"): del self.replay - observation = self.game.agent_observation(self.agent_id) - info: dict[str, Any] = {} - return observation, info - - def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]: - # Get action of NPC - npc_observation = self.game.agent_observation(self.npc.id) - npc_action = self.npc.act(npc_observation) - actions = {self.agent_id: action, self.npc.id: npc_action} - - observations, infos = self.game.step(actions) + _obs = {agent: self.game.agent_observation(agent) for agent in self.agents} + observations = np.stack([_obs[agent].as_tensor() for agent in self.agents], dtype=np.float32) + + infos: dict[str, Any] = self.game.get_infos() + # flatten infos + infos = { + agent: [ + infos[agent]["army"], + infos[agent]["land"], + infos[agent]["is_done"], + infos[agent]["is_winner"], + compute_valid_move_mask(_obs[agent]), + ] + for i, agent in enumerate(self.agents) + } + return observations, infos + + def step(self, actions: list[Action]) -> tuple[Any, Any, bool, bool, dict[str, Any]]: + _actions = { + self.agents[0]: actions[0], + self.agents[1]: actions[1], + } + + observations, infos = self.game.step(_actions) + obs1 = self.game.agent_observation(self.agents[0]).as_tensor() + obs2 = self.game.agent_observation(self.agents[1]).as_tensor() + obs = np.stack([obs1, obs2]) + + # flatten infos + infos = { + agent: [ + infos[agent]["army"], + infos[agent]["land"], + infos[agent]["is_done"], + infos[agent]["is_winner"], + compute_valid_move_mask(observations[agent]), + ] + for agent in self.agents + } # From observations of all agents, pick only those relevant for the main agent - obs = observations[self.agent_id] - info = infos[self.agent_id] - if 
self.prior_observation is None: + if self.prior_observations is None: # Cannot compute a reward without a prior-observation. This should only happen # on the first time-step. - reward = 0.0 + rewards = [0.0, 0.0] else: - reward = self.reward_fn( - prior_obs=self.prior_observation, - # Technically, action is the prior-action, since it's what gives rise to the - # current observation. - prior_action=action, - obs=obs, - ) + rewards = [ + self.reward_fn( + prior_obs=self.prior_observations[agent], + # Technically, action is the prior-action, since it's what gives rise to the + # current observation. + prior_action=_actions[agent], + obs=observations[agent], + ) + for agent in self.agents + ] + rewards = 0 # WIP terminated = self.game.is_done() truncated = False @@ -181,8 +180,8 @@ def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, self.observation_space = self.set_observation_space() self.action_space = self.set_action_space() - self.prior_observation = observations[self.agent_id] - return obs, reward, terminated, truncated, info + self.prior_observations = {agent: observations[agent] for agent in self.agents} + return obs, rewards, terminated, truncated, infos def close(self) -> None: if self.render_mode == "human": diff --git a/generals/envs/gymnasium_wrappers.py b/generals/envs/gymnasium_wrappers.py deleted file mode 100644 index ecc9d5a..0000000 --- a/generals/envs/gymnasium_wrappers.py +++ /dev/null @@ -1,55 +0,0 @@ -import gymnasium as gym -import numpy as np - - -class RemoveActionMaskWrapper(gym.ObservationWrapper): - def __init__(self, env): - super().__init__(env) - self.observation_space = env.observation_space["observation"] - - def observation(self, observation): - _observation = observation["observation"] if "observation" in observation else observation - return _observation - - -class ObservationAsImageWrapper(gym.ObservationWrapper): - def __init__(self, env): - super().__init__(env) - n_obs_keys = len(self.observation_space["observation"].items()) - self.game = env.game - self.observation_space = gym.spaces.Dict( - { - "observation": gym.spaces.Box( - low=0, high=1, shape=self.game.grid_dims + (n_obs_keys,), dtype=np.float32 - ), - "action_mask": gym.spaces.MultiBinary(self.game.grid_dims + (4,)), - } - ) - - def observation(self, observation): - game = self.game - _obs = observation["observation"] if "observation" in observation else observation - _obs = ( - np.stack( - [ - _obs["armies"] / game.max_army_value, - _obs["generals"], - _obs["cities"], - _obs["mountains"], - _obs["neutral_cells"], - _obs["owned_cells"], - _obs["opponent_cells"], - _obs["fog_cells"], - _obs["structures_in_fog"], - np.ones(game.grid_dims) * _obs["owned_land_count"] / game.max_land_value, - np.ones(game.grid_dims) * _obs["owned_army_count"] / game.max_army_value, - np.ones(game.grid_dims) * _obs["opponent_land_count"] / game.max_land_value, - np.ones(game.grid_dims) * _obs["opponent_army_count"] / game.max_army_value, - np.ones(game.grid_dims) * _obs["timestep"] / game.max_timestep, - np.ones(game.grid_dims) * _obs["priority"], - ] - ) - .astype(np.float32) - .transpose(1, 2, 0) - ) - return _obs diff --git a/generals/envs/initializers.py b/generals/envs/initializers.py deleted file mode 100644 index 36ea0b7..0000000 --- a/generals/envs/initializers.py +++ /dev/null @@ -1,61 +0,0 @@ -from generals import GridFactory -from generals.agents import Agent -from generals.envs.gymnasium_generals import GymnasiumGenerals -from generals.envs.gymnasium_wrappers import 
ObservationAsImageWrapper, RemoveActionMaskWrapper -from generals.rewards.reward_fn import RewardFn - -""" -Here we can define environment initialization functions that -can create interesting types of environments. In case of -Gymnasium environments, please register these functions also -in the generals/__init__.py, so they can be created via gym.make ----------------------------------------------------------- -Feel free to add more initializers here. It is a good place -to create "pre-wrapped" envs, or envs with custom maps or other -custom settings. -""" - - -def gym_image_observations( - grid_factory: GridFactory | None = None, - npc: Agent | None = None, - agent: Agent | None = None, - render_mode: str | None = None, - reward_fn: RewardFn | None = None, -): - """ - Example of a Gymnasium environment initializer that creates - an environment that returns image observations. - """ - _env = GymnasiumGenerals( - grid_factory=grid_factory, - npc=npc, - agent=agent, - render_mode=render_mode, - reward_fn=reward_fn, - ) - env = ObservationAsImageWrapper(_env) - return env - - -def gym_rllib( - grid_factory: GridFactory | None = None, - npc: Agent | None = None, - agent: Agent | None = None, - render_mode: str | None = None, - reward_fn: RewardFn | None = None, -): - """ - Example of a Gymnasium environment initializer that creates - an environment that returns image observations. - """ - env = GymnasiumGenerals( - grid_factory=grid_factory, - npc=npc, - agent=agent, - render_mode=render_mode, - reward_fn=reward_fn, - ) - image_env = ObservationAsImageWrapper(env) - no_action_env = RemoveActionMaskWrapper(image_env) - return no_action_env diff --git a/generals/envs/multiagent_gymnasium_generals.py b/generals/envs/multiagent_gymnasium_generals.py deleted file mode 100644 index e81daf8..0000000 --- a/generals/envs/multiagent_gymnasium_generals.py +++ /dev/null @@ -1,188 +0,0 @@ -# type: ignore -from copy import deepcopy -from typing import Any - -import gymnasium as gym -import numpy as np -from gymnasium import spaces - -from generals.core.action import Action, compute_valid_move_mask -from generals.core.game import Game -from generals.core.grid import Grid, GridFactory -from generals.core.observation import Observation -from generals.core.replay import Replay -from generals.gui import GUI -from generals.gui.properties import GuiMode -from generals.rewards.reward_fn import RewardFn -from generals.rewards.win_lose_reward_fn import WinLoseRewardFn - - -class MultiAgentGymnasiumGenerals(gym.Env): - metadata = { - "render_modes": ["human"], - "render_fps": 6, - } - - def __init__( - self, - agents: list[str], - grid_factory: GridFactory | None = None, - truncation: int | None = None, - reward_fn: RewardFn | None = None, - render_mode: str | None = None, - ): - self.render_mode = render_mode - self.grid_factory = grid_factory if grid_factory is not None else GridFactory() - self.reward_fn = reward_fn if reward_fn is not None else WinLoseRewardFn() - # Observation for the agent at the prior time-step. 
- self.prior_observations: None | dict[str, Observation] = None - - self.agents = agents - self.colors = [(255, 107, 108), (0, 130, 255)] - self.agent_data = {id: {"color": color} for id, color in zip(agents, self.colors)} - - # Game - grid = self.grid_factory.generate() - self.game = Game(grid, agents) - self.truncation = truncation - self.observation_space = self.set_observation_space() - self.action_space = self.set_action_space() - - def set_observation_space(self) -> spaces.Space: - """ - If grid_factory has padding on, grid (and therefore observations) will be padded to the same shape, - which corresponds to the maximum grid dimensions of grid_factory. - Otherwise, the observatoin shape might change depending on the currently generated grid. - - Note: The grid is padded with mountains from right and bottom. We recommend using the padded - grids for training purposes, as it will make the observations consistent across episodes. - """ - if self.grid_factory.padding: - dims = self.grid_factory.max_grid_dims - else: - dims = self.game.grid_dims - - return spaces.Box(low=0, high=2**31 - 1, shape=(2, 15, dims[0], dims[1]), dtype=np.float32) - - def set_action_space(self) -> spaces.Space: - if self.grid_factory.padding: - dims = self.grid_factory.max_grid_dims - else: - dims = self.game.grid_dims - return spaces.MultiDiscrete([2, dims[0], dims[1], 4, 2]) - - def render(self): - if self.render_mode == "human": - _ = self.gui.tick(fps=self.metadata["render_fps"]) - - def reset( - self, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[Observation, dict[str, Any]]: - super().reset(seed=seed) - if options is None: - options = {} - - if "grid" in options: - grid = Grid(options["grid"]) - else: - # Provide the np.random.Generator instance created in Env.reset() - # as opposed to creating a new one with the same seed. 
- self.grid_factory.set_rng(rng=self.np_random) - grid = self.grid_factory.generate() - - # Create game for current run - self.game = Game(grid, self.agents) - - # Create GUI for current render run - if self.render_mode == "human": - self.gui = GUI(self.game, self.agent_data, GuiMode.TRAIN) - - if "replay_file" in options: - self.replay = Replay( - name=options["replay_file"], - grid=grid, - agent_data=self.agent_data, - ) - self.replay.add_state(deepcopy(self.game.channels)) - elif hasattr(self, "replay"): - del self.replay - - _obs = {agent: self.game.agent_observation(agent) for agent in self.agents} - observations = np.stack([_obs[agent].as_tensor() for agent in self.agents], dtype=np.float32) - - infos: dict[str, Any] = self.game.get_infos() - # flatten infos - infos = { - agent: [ - infos[agent]["army"], - infos[agent]["land"], - infos[agent]["is_done"], - infos[agent]["is_winner"], - compute_valid_move_mask(_obs[agent]), - ] - for i, agent in enumerate(self.agents) - } - return observations, infos - - def step(self, actions: list[Action]) -> tuple[Any, Any, bool, bool, dict[str, Any]]: - _actions = { - self.agents[0]: actions[0], - self.agents[1]: actions[1], - } - - observations, infos = self.game.step(_actions) - obs1 = self.game.agent_observation(self.agents[0]).as_tensor() - obs2 = self.game.agent_observation(self.agents[1]).as_tensor() - obs = np.stack([obs1, obs2]) - - # flatten infos - infos = { - agent: [ - infos[agent]["army"], - infos[agent]["land"], - infos[agent]["is_done"], - infos[agent]["is_winner"], - compute_valid_move_mask(observations[agent]), - ] - for agent in self.agents - } - - # From observations of all agents, pick only those relevant for the main agent - if self.prior_observations is None: - # Cannot compute a reward without a prior-observation. This should only happen - # on the first time-step. - rewards = [0.0, 0.0] - else: - rewards = [ - self.reward_fn( - prior_obs=self.prior_observations[agent], - # Technically, action is the prior-action, since it's what gives rise to the - # current observation. 
- prior_action=_actions[agent], - obs=observations[agent], - ) - for agent in self.agents - ] - rewards = 0 # WIP - - terminated = self.game.is_done() - truncated = False - if self.truncation is not None: - truncated = self.game.time >= self.truncation - - if hasattr(self, "replay"): - self.replay.add_state(deepcopy(self.game.channels)) - - if terminated or truncated: - if hasattr(self, "replay"): - self.replay.store() - - self.observation_space = self.set_observation_space() - self.action_space = self.set_action_space() - - self.prior_observations = {agent: observations[agent] for agent in self.agents} - return obs, rewards, terminated, truncated, infos - - def close(self) -> None: - if self.render_mode == "human": - self.gui.close() diff --git a/pyproject.toml b/pyproject.toml index c174df6..2cb54e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,3 @@ line-length = 120 lint.select = ["E", "F", "I", "BLE", "UP", "FA"] target-version = "py311" exclude = ["tests", "examples"] - -[tool.mypy] -exclude = ["tests/", "examples/", "generals/agents/"] -follow_imports="silent" diff --git a/tests/test_gym.py b/tests/test_gym.py deleted file mode 100644 index c39afe2..0000000 --- a/tests/test_gym.py +++ /dev/null @@ -1,15 +0,0 @@ -import gymnasium as gym -import gymnasium.utils.env_checker as env_checker - -from generals.agents import RandomAgent - - -def test_gym_runs(): - npc = RandomAgent() - - env = gym.make( - "gym-generals-v0", - npc=npc, - ) - env_checker.check_env(env.unwrapped) - print("Gymnasium check passed!")
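Taken together, the patched `GymnasiumGenerals` exposes a two-player, tensor-based interface: `reset` returns observations stacked to shape `(2, 15, H, W)` together with per-agent infos flattened to `[army, land, is_done, is_winner, valid_move_mask]`, and `step` expects one action per agent in the order the agent names were passed to the constructor. A minimal, non-vectorized loop might look like the sketch below; the agent names are placeholders, and since the reward path is still marked WIP in this patch, the returned rewards are a placeholder value rather than output of the configured reward function.

```python
from generals.envs import GymnasiumGenerals

agent_names = ["007", "Generalissimo"]  # placeholder ids
env = GymnasiumGenerals(agents=agent_names, truncation=500)

observations, infos = env.reset()  # observations: float32 array of shape (2, 15, H, W)
terminated = truncated = False

while not (terminated or truncated):
    # One action per agent, in the same order as `agent_names`.
    actions = [env.action_space.sample() for _ in agent_names]
    observations, rewards, terminated, truncated, infos = env.step(actions)

    # Per-agent info is flattened to [army, land, is_done, is_winner, valid_move_mask].
    army, land, is_done, is_winner, mask = infos[agent_names[0]]
```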