Skip to content

Commit

Permalink
Merge pull request #62 from strakam/redo-maps
Browse files Browse the repository at this point in the history
refactor: From Map to GridFactory
  • Loading branch information
strakam authored Oct 1, 2024
2 parents 200fb99 + 92be9cf commit d8999e1
Show file tree
Hide file tree
Showing 13 changed files with 223 additions and 274 deletions.
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ replay:
at:
pytest tests/test_game.py
pytest tests/test_map.py
pytest tests/test_replay.py
python3 tests/gym_check.py
python3 tests/sb3_check.py

Expand Down
6 changes: 3 additions & 3 deletions examples/complete_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from generals.env import pz_generals
from generals.agents import RandomAgent, ExpanderAgent
from generals.map import Mapper
from generals.grid import GridFactory

# Initialize agents - their names are then called for actions
randomer = RandomAgent("Random1", color=(255, 125, 0))
Expand All @@ -11,7 +11,7 @@
expander.name: expander,
}

mapper = Mapper(
gf = GridFactory(
grid_dims=(4, 8), # height x width
mountain_density=0.2,
city_density=0.05,
Expand All @@ -27,7 +27,7 @@
"""

# Create environment
env = pz_generals(mapper, agents, render_mode=None) # Disable rendering
env = pz_generals(gf, agents, render_mode=None) # Disable rendering

options = {
"map": map,
Expand Down
17 changes: 10 additions & 7 deletions generals/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,39 @@
from .integrations.gymnasium_integration import Gym_Generals, RewardFn
from .integrations.pettingzoo_integration import PZ_Generals

from .map import Mapper
from .grid import GridFactory


def pz_generals(
mapper: Mapper = Mapper(),
grid_factory: GridFactory = GridFactory(),
agents: dict[str, Agent] = None,
reward_fn: RewardFn=None,
reward_fn: RewardFn = None,
render_mode=None,
):
"""
Here we apply wrappers to the environment.
"""
env = PZ_Generals(
mapper=mapper, agents=agents, reward_fn=reward_fn, render_mode=render_mode
grid_factory=grid_factory,
agents=agents,
reward_fn=reward_fn,
render_mode=render_mode,
)
return env


def gym_generals(
mapper: Mapper = Mapper(),
grid_factory: GridFactory = GridFactory(),
agent: Agent = None,
npc: Agent = None,
reward_fn: RewardFn=None,
reward_fn: RewardFn = None,
render_mode=None,
):
"""
Here we apply wrappers to the environment.
"""
env = Gym_Generals(
mapper=mapper,
grid_factory=grid_factory,
agent=agent,
npc=npc,
reward_fn=reward_fn,
Expand Down
15 changes: 8 additions & 7 deletions generals/game.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Any

from generals.grid import Grid
import numpy as np
import gymnasium as gym
from typing_extensions import TypeAlias
Expand All @@ -14,19 +15,17 @@


class Game:
def __init__(self, map: np.ndarray, agents: list[str]):
def __init__(self, grid: Grid, agents: list[str]):
self.agents = agents
map = grid.numpified_grid
self.grid_dims = (map.shape[0], map.shape[1])
self.map = map
self.time = 0

self.general_positions = {
agent: np.argwhere(map == chr(ord("A") + i))[0]
for i, agent in enumerate(self.agents)
}

valid_generals = ["A", "B"] # because generals are represented as letters

# Initialize channels
# Army - army size in each cell
# General - general mask (1 if general is in cell, 0 otherwise)
Expand All @@ -36,6 +35,8 @@ def __init__(self, map: np.ndarray, agents: list[str]):
# Ownership_i - ownership mask for player i (1 if player i owns cell, 0 otherwise)
# Ownerhsip_0 - ownership mask for neutral cells that are passable (1 if cell is neutral, 0 otherwise)
# Initialize channels

valid_generals = ["A", "B"] # Generals are represented by A and B
self.channels = {
"army": np.where(np.isin(map, valid_generals), 1, 0).astype(int),
"general": np.where(np.isin(map, valid_generals), 1, 0).astype(bool),
Expand Down Expand Up @@ -80,7 +81,7 @@ def __init__(self, map: np.ndarray, agents: list[str]):
"opponent_land_count": gym.spaces.Discrete(max_value),
"opponent_army_count": gym.spaces.Discrete(max_value),
"is_winner": gym.spaces.Discrete(2),
"timestep": gym.spaces.Discrete(max_value)
"timestep": gym.spaces.Discrete(max_value),
}
),
"action_mask": gym.spaces.MultiBinary(self.grid_dims + (4,)),
Expand Down Expand Up @@ -335,9 +336,9 @@ def _agent_observation(self, agent: str) -> Observation:
}
observation = {
"observation": _observation,
"action_mask": self.action_mask(agent)
"action_mask": self.action_mask(agent),
}

return observation

def agent_won(self, agent: str) -> bool:
Expand Down
142 changes: 142 additions & 0 deletions generals/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import numpy as np
from generals.config import PASSABLE, MOUNTAIN


class Grid:
def __init__(self, grid: str):
if not Grid.verify_grid(grid):
raise ValueError("Invalid grid layout - generals cannot reach each other.")
self._grid_string = grid.strip()

@property
def grid(self):
return self._grid_string

@grid.setter
def grid(self, grid: str):
grid = grid.strip()
if not Grid.verify_grid(grid):
raise ValueError("Invalid grid layout - generals cannot reach each other.")
self._grid_string = grid

@property
def numpified_grid(self):
return Grid.numpify_grid(self._grid_string)

@staticmethod
def numpify_grid(grid: str) -> np.ndarray:
return np.array([list(row) for row in grid.strip().split("\n")])

@staticmethod
def stringify_grid(grid: np.ndarray) -> str:
return "\n".join(["".join(row) for row in grid])

@staticmethod
def verify_grid(grid: str) -> bool:
"""
Verify grid layout (can generals reach each other?)
Returns True if grid is valid, False otherwise
"""

def dfs(grid, visited, square):
i, j = square
if (
i < 0
or i >= grid.shape[0]
or j < 0
or j >= grid.shape[1]
or visited[i, j]
):
return
if grid[i, j] == MOUNTAIN:
return
visited[i, j] = True
for di, dj in [[-1, 0], [1, 0], [0, -1], [0, 1]]:
new_square = (i + di, j + dj)
dfs(grid, visited, new_square)

grid = Grid.numpify_grid(grid)
generals = np.argwhere(np.isin(grid, ["A", "B"]))
start, end = generals[0], generals[1]
visited = np.zeros_like(grid, dtype=bool)
dfs(grid, visited, start)
return visited[end[0], end[1]]

def __str__(self):
return self._grid_string


class GridFactory:
def __init__(
self,
grid_dims: tuple[int, int] = (10, 10),
mountain_density: float = 0.2,
city_density: float = 0.05,
general_positions: list[tuple[int, int]] = None,
seed: int = None,
):
self.grid_height = grid_dims[0]
self.grid_width = grid_dims[1]
self.mountain_density = mountain_density
self.city_density = city_density
self.general_positions = general_positions
self.seed = seed

def grid_from_string(self, grid: str) -> Grid:
return Grid(grid)

def grid_from_generator(
self,
grid_dims: tuple[int, int] = None,
mountain_density: float = None,
city_density: float = None,
general_positions: list[tuple[int, int]] = None,
seed: int = None,
) -> Grid:
if grid_dims is None:
grid_dims = (self.grid_height, self.grid_width)
if mountain_density is None:
mountain_density = self.mountain_density
if city_density is None:
city_density = self.city_density
if general_positions is None:
general_positions = self.general_positions
if seed is None:
seed = self.seed

# Probabilities of each cell type
p_neutral = 1 - mountain_density - city_density
probs = [p_neutral, mountain_density] + [city_density / 10] * 10

# Place cells on the map
rng = np.random.default_rng(seed)
map = rng.choice(
[PASSABLE, MOUNTAIN, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
size=grid_dims,
p=probs,
)

# Place generals on random squares - generals_positions is a list of two tuples
if general_positions is None:
general_positions = []
while len(general_positions) < 2:
position = tuple(rng.integers(0, grid_dims))
if position not in general_positions:
general_positions.append(position)

for i, idx in enumerate(general_positions):
map[idx[0], idx[1]] = chr(ord("A") + i)

# Convert map to string
map_string = "\n".join(["".join(row) for row in map.astype(str)])

try:
return Grid(map_string)
except ValueError:
return self.grid_from_generator(
grid_dims=grid_dims,
mountain_density=mountain_density,
city_density=city_density,
general_positions=general_positions,
seed=seed,
)
22 changes: 10 additions & 12 deletions generals/integrations/gymnasium_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

from ..agents import Agent
from ..game import Game, Observation
from ..gui import GUI
from ..map import Mapper
from ..grid import GridFactory
from ..replay import Replay

# Type aliases
Expand All @@ -20,15 +19,15 @@
class Gym_Generals(gym.Env):
def __init__(
self,
mapper: Mapper,
grid_factory: GridFactory,
agent: Agent,
npc: Agent,
reward_fn: RewardFn = None,
render_mode=None,
):
self.render_mode = render_mode
self.reward_fn = self.default_reward if reward_fn is None else reward_fn
self.mapper = mapper
self.grid_factory = grid_factory

self.agent_name = agent.name
self.npc = npc
Expand All @@ -40,8 +39,8 @@ def __init__(
agent.name != npc.name
), "Agent names must be unique - you can pass custom names to agent constructors."

map = self.mapper.get_map(numpify=True)
game = Game(map, [self.agent_name, self.npc.name])
grid = self.grid_factory.grid_from_generator()
game = Game(grid, [self.agent_name, self.npc.name])
self.observation_space = game.observation_space
self.action_space = game.action_space

Expand All @@ -62,13 +61,12 @@ def reset(self, seed=None, options=None):
options = {}
super().reset(seed=seed)
# If map is not provided, generate a new one
if "map" in options:
map = options["map"]
if "grid" in options:
grid = self.grid_factory.grid_from_string(options["grid"])
else:
self.mapper.reset() # Generate new map
map = self.mapper.get_map()
grid = self.grid_factory.grid_from_generator()

self.game = Game(self.mapper.numpify_map(map), [self.agent_name, self.npc.name])
self.game = Game(grid, [self.agent_name, self.npc.name])
self.npc.reset()

self.observation_space = self.game.observation_space
Expand All @@ -80,7 +78,7 @@ def reset(self, seed=None, options=None):
if "replay_file" in options:
self.replay = Replay(
name=options["replay_file"],
map=map,
grid=grid,
agent_data=self.agent_colors,
)
self.replay.add_state(deepcopy(self.game.channels))
Expand Down
Loading

0 comments on commit d8999e1

Please sign in to comment.