From 0732d3769e1784e833dfcb25c4ffa31a999e8528 Mon Sep 17 00:00:00 2001 From: Matej Straka Date: Mon, 14 Oct 2024 19:59:53 +0200 Subject: [PATCH] refactor: Improve some imports, remove whitespace --- README.md | 2 +- examples/gymnasium_example.py | 6 +++--- examples/pettingzoo_example.py | 9 +++++---- generals/__init__.py | 2 +- generals/agents/__init__.py | 10 +++++++++- generals/core/config.py | 5 ++--- 6 files changed, 21 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 210655e..2c65899 100644 --- a/README.md +++ b/README.md @@ -173,7 +173,7 @@ The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N | `is_winner` | — | Indicates whether the agent won | | `timestep` | — | Current timestep of the game | -The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell +The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell `[i, j]` in one of four directions: `0 (up)`, `1 (down)`, `2 (left)`, or `3 (right)`. ### ⚡ Action diff --git a/examples/gymnasium_example.py b/examples/gymnasium_example.py index 34b22b6..d2e7b0d 100644 --- a/examples/gymnasium_example.py +++ b/examples/gymnasium_example.py @@ -2,8 +2,8 @@ from generals import AgentFactory -# Initialize opponent agent ("random" or "expander") -npc = AgentFactory.make_agent("random") +# Initialize opponent agent +npc = AgentFactory.make_agent("expander") # Create environment env = gym.make("gym-generals-v0", npc=npc, render_mode="human") @@ -11,6 +11,6 @@ observation, info = env.reset() terminated = truncated = False while not (terminated or truncated): - action = env.action_space.sample() # Here you put your agent's action + action = env.action_space.sample() # Here you put an action of your agent observation, reward, terminated, truncated, info = env.step(action) env.render() diff --git a/examples/pettingzoo_example.py b/examples/pettingzoo_example.py index 69afbbb..7576386 100644 --- a/examples/pettingzoo_example.py +++ b/examples/pettingzoo_example.py @@ -1,15 +1,16 @@ -from generals.agents import AgentFactory +from generals.agents import RandomAgent, ExpanderAgent from generals.envs import PettingZooGenerals # Initialize agents -random = AgentFactory.make_agent("random") -expander = AgentFactory.make_agent("expander") +random = RandomAgent() +expander = ExpanderAgent() +# Store agents in a dictionary agents = { random.id: random, expander.id: expander, } -agent_ids = list(agents.keys()) # Environment calls agents by name +agent_ids = list(agents.keys()) # Create environment env = PettingZooGenerals(agents=agent_ids, render_mode="human") diff --git a/generals/__init__.py b/generals/__init__.py index 6da94a8..121227c 100644 --- a/generals/__init__.py +++ b/generals/__init__.py @@ -1,8 +1,8 @@ from gymnasium.envs.registration import register -from generals.agents.agent_factory import AgentFactory from generals.core.grid import Grid, GridFactory from generals.core.replay import Replay +from generals.agents import AgentFactory from generals.envs.pettingzoo_generals import PettingZooGenerals __all__ = [ diff --git a/generals/agents/__init__.py b/generals/agents/__init__.py index 43dfa76..3812681 100644 --- a/generals/agents/__init__.py +++ b/generals/agents/__init__.py @@ -2,6 +2,14 @@ from .agent import Agent from .agent_factory import AgentFactory +from .expander_agent import ExpanderAgent +from .random_agent import RandomAgent # You can also define an __all__ list if you want to restrict what gets imported with * -__all__ = ["Agent", "AgentFactory"] +__all__ = [ + "Agent", + "AgentFactory", + "RandomAgent", + "ExpanderAgent", + "AgentFactory", +] diff --git a/generals/core/config.py b/generals/core/config.py index 208f720..513b65f 100644 --- a/generals/core/config.py +++ b/generals/core/config.py @@ -5,6 +5,7 @@ import gymnasium as gym import numpy as np +# Type aliases Observation: TypeAlias = dict[str, np.ndarray | dict[str, gym.Space]] Action: TypeAlias = dict[str, int | np.ndarray] Info: TypeAlias = dict[str, Any] @@ -13,9 +14,7 @@ RewardFn: TypeAlias = Callable[[Observation, Action, bool, Info], Reward] AgentID: TypeAlias = str -################# -# Game Literals # -################# +# Game Literals PASSABLE: Literal["."] = "." MOUNTAIN: Literal["#"] = "#"