Skip to content

Commit

Permalink
Merge pull request #47 from strakam/restructure-rendering
Browse files Browse the repository at this point in the history
Restructure code
  • Loading branch information
strakam authored Sep 28, 2024
2 parents cd0bf3b + 3c4a5bb commit d40fcf6
Show file tree
Hide file tree
Showing 39 changed files with 133 additions and 130 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pip install -e .
## Usage Example (🦁 PettingZoo)
```python
from generals.env import pz_generals
from generals.agent import ExpanderAgent, RandomAgent
from generals.agents import ExpanderAgent, RandomAgent

# Initialize agents
random = RandomAgent()
Expand Down Expand Up @@ -76,7 +76,7 @@ while not env.game.is_done():
## Usage example (🤸 Gymnasium)
```python
from generals.env import gym_generals
from generals.agent import RandomAgent, ExpanderAgent
from generals.agents import RandomAgent, ExpanderAgent

# Initialize agents
agent = RandomAgent()
Expand Down
16 changes: 0 additions & 16 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -1,28 +1,12 @@
## TODOs

### Features
- [x] Let user pass argument indicating what is the default opponent agent for Gymnasium environment.
- [x] Extend action space by allowing user to be IDLE.
- [ ] Add human control, so people can play against bots.
- [x] Random agent new parameters

### Improvements
- [x] Let user change game speed even when the game is paused; make 'Paused' text separate in the side panel.
- [x] Revisit types for observation space (np.float32 vs np.bool)
- [x] Make ExpanderAgent a bit more readable if possible
- [x] Test IDLE actions
- [x] Should we error out when agent tries to perform an invalid move, so that it is easier to debug?
- [x] Redo how replays are stored and loaded
- [x] In config, resolve circular dependency in a cleaner manner
- [ ] Implement .close() method in envs and instead of quitting in renderer, quit in env

### Bug fixes

### Documentation and CI
- [ ] Create more examples of usage (Stable Baselines3 demo)
- [x] Use gymnasium check_env
- [x] Pre-commit hooks for conventional commit checks (enforcing conventional commits)
- [x] Add CI for running tests (pre commit)
- [x] Add CI passing badge to README
- [x] Document agent action/move format
- [x] Split game step tests into more specific tests
4 changes: 2 additions & 2 deletions examples/complete_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from generals.env import pz_generals
from generals.agent import RandomAgent, ExpanderAgent
from generals.agents import RandomAgent, ExpanderAgent
from generals.map import Mapper

# Initialize agents - their names are then called for actions
Expand All @@ -13,7 +13,7 @@

# Mapper will be default generate 4x4 maps
mapper = Mapper(
grid_size=4,
grid_dims=(4, 8), # width x height
mountain_density=0.2,
city_density=0.05,
general_positions=[(0, 0), (3, 3)],
Expand Down
2 changes: 1 addition & 1 deletion examples/gymnasium_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from generals.env import gym_generals
from generals.agent import RandomAgent, ExpanderAgent
from generals.agents import RandomAgent, ExpanderAgent

# Initialize agents
agent = RandomAgent()
Expand Down
2 changes: 1 addition & 1 deletion examples/pettingzoo_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from generals.env import pz_generals
from generals.agent import ExpanderAgent, RandomAgent
from generals.agents import ExpanderAgent, RandomAgent

# Initialize agents
random = RandomAgent()
Expand Down
8 changes: 8 additions & 0 deletions generals/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# agents/__init__.py

from .random_agent import RandomAgent
from .expander_agent import ExpanderAgent
from .agent import Agent

# You can also define an __all__ list if you want to restrict what gets imported with *
__all__ = ["Agent", "RandomAgent", "ExpanderAgent"]
24 changes: 24 additions & 0 deletions generals/agents/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
class Agent:
"""
Base class for all agents.
"""

def __init__(self):
pass

def play(self):
"""
This method should be implemented by the child class.
It should receive an observation and return an action.
"""
raise NotImplementedError

def reset(self):
"""
This method allows the agent to reset its state.
If not needed, just pass.
"""
raise NotImplementedError

def __str__(self):
return self.name
58 changes: 1 addition & 57 deletions generals/agent.py → generals/agents/expander_agent.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,9 @@
import numpy as np
from .agent import Agent

from generals.config import DIRECTIONS


class Agent:
"""
Base class for all agents.
"""

def __init__(self):
pass

def play(self):
"""
This method should be implemented by the child class.
It should receive an observation and return an action.
"""
raise NotImplementedError

def reset(self):
"""
This method allows the agent to reset its state.
If not needed, just pass.
"""
raise NotImplementedError

def __str__(self):
return self.name


class RandomAgent(Agent):
def __init__(
self, idle_prob=0.05, split_prob=0.25, name="Random", color=(255, 0, 0)
):
self.name = name
self.color = color

self.idle_probability = idle_prob
self.split_probability = split_prob

def play(self, observation):
"""
Randomly selects a valid action.
"""
mask = observation["action_mask"]
valid_actions = np.argwhere(mask == 1)
if len(valid_actions) == 0: # No valid actions
return np.array([1, 0, 0, 0, 0]) # Pass the move

pass_turn = [0] if np.random.rand() > self.idle_probability else [1]
split_army = [0] if np.random.rand() > self.split_probability else [1]

action_index = np.random.choice(len(valid_actions))

action = np.concatenate((pass_turn, valid_actions[action_index], split_army))
return action

def reset(self):
pass


class ExpanderAgent(Agent):
def __init__(self, name="Expander", color=(0, 130, 255)):
self.name = name
Expand Down
32 changes: 32 additions & 0 deletions generals/agents/random_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .agent import Agent
import numpy as np

class RandomAgent(Agent):
def __init__(
self, idle_prob=0.05, split_prob=0.25, name="Random", color=(255, 0, 0)
):
self.name = name
self.color = color

self.idle_probability = idle_prob
self.split_probability = split_prob

def play(self, observation):
"""
Randomly selects a valid action.
"""
mask = observation["action_mask"]
valid_actions = np.argwhere(mask == 1)
if len(valid_actions) == 0: # No valid actions
return np.array([1, 0, 0, 0, 0]) # Pass the move

pass_turn = [0] if np.random.rand() > self.idle_probability else [1]
split_army = [0] if np.random.rand() > self.split_probability else [1]

action_index = np.random.choice(len(valid_actions))

action = np.concatenate((pass_turn, valid_actions[action_index], split_army))
return action

def reset(self):
pass
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
8 changes: 4 additions & 4 deletions generals/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
FONT_TYPE = "Quicksand-Medium.ttf" # Font options are Quicksand-SemiBold.ttf, Quicksand-Medium.ttf, Quicksand-Light.ttf
FONT_SIZE = 18
try:
file_ref = files("generals.fonts") / FONT_TYPE
file_ref = files("generals.assets.fonts") / FONT_TYPE
FONT_PATH = str(file_ref)
except FileNotFoundError:
raise FileNotFoundError(f"Font file {FONT_TYPE} not found in the fonts directory")
Expand All @@ -59,14 +59,14 @@
# Icons #
#########
try:
GENERAL_PATH = str(files("generals.images") / "crownie.png")
GENERAL_PATH = str(files("generals.assets.images") / "crownie.png")
except FileNotFoundError:
raise FileNotFoundError("Image not found")
try:
CITY_PATH = str(files("generals.images") / "citie.png")
CITY_PATH = str(files("generals.assets.images") / "citie.png")
except FileNotFoundError:
raise FileNotFoundError("Image not found")
try:
MOUNTAIN_PATH = str(files("generals.images") / "mountainie.png")
MOUNTAIN_PATH = str(files("generals.assets.images") / "mountainie.png")
except FileNotFoundError:
raise FileNotFoundError("Image not found")
25 changes: 15 additions & 10 deletions generals/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ def __init__(self, map: np.ndarray, agents: List[str]):
self.agents = agents
self.time = 0

spatial_dim = (map.shape[0], map.shape[1])
self.grid_dims = (map.shape[0], map.shape[1])
self.map = map
self.grid_size = spatial_dim[0] # Grid shape should be square

self.general_positions = {
agent: np.argwhere(map == chr(ord("A") + i))[0]
Expand Down Expand Up @@ -57,11 +56,11 @@ def __init__(self, map: np.ndarray, agents: List[str]):
##########
# Spaces #
##########
grid_multi_binary = gym.spaces.MultiBinary(spatial_dim)
grid_multi_binary = gym.spaces.MultiBinary(self.grid_dims)
self.observation_space = gym.spaces.Dict(
{
"army": gym.spaces.Box(
low=0, high=1e5, shape=spatial_dim, dtype=np.float32
low=0, high=1e5, shape=self.grid_dims, dtype=np.float32
),
"general": grid_multi_binary,
"city": grid_multi_binary,
Expand All @@ -71,7 +70,7 @@ def __init__(self, map: np.ndarray, agents: List[str]):
"visibile_cells": grid_multi_binary,
"structure": grid_multi_binary,
"action_mask": gym.spaces.MultiBinary(
(spatial_dim[0], spatial_dim[1], 4)
(self.grid_dims[0], self.grid_dims[1], 4)
),
"owned_land_count": gym.spaces.Discrete(np.iinfo(np.int64).max),
"owned_army_count": gym.spaces.Discrete(np.iinfo(np.int64).max),
Expand All @@ -83,7 +82,7 @@ def __init__(self, map: np.ndarray, agents: List[str]):
)

self.action_space = gym.spaces.MultiDiscrete(
[2, self.grid_size, self.grid_size, 4, 2]
[2, self.grid_dims[0], self.grid_dims[1], 4, 2]
)

def action_mask(self, agent: str) -> np.ndarray:
Expand All @@ -106,7 +105,9 @@ def action_mask(self, agent: str) -> np.ndarray:
ownership_channel = self.channels[f"ownership_{agent}"]
more_than_1_army = (self.channels["army"] > 1) * ownership_channel
owned_cells_indices = self.channel_to_indices(more_than_1_army)
valid_action_mask = np.zeros((self.grid_size, self.grid_size, 4), dtype=bool)
valid_action_mask = np.zeros(
(self.grid_dims[0], self.grid_dims[1], 4), dtype=bool
)

if self.is_done():
return valid_action_mask
Expand All @@ -116,8 +117,11 @@ def action_mask(self, agent: str) -> np.ndarray:

# check if destination is in grid bounds
in_first_boundary = np.all(destinations >= 0, axis=1)
in_second_boundary = np.all(destinations < self.grid_size, axis=1)
destinations = destinations[in_first_boundary & in_second_boundary]
in_height_boundary = destinations[:, 0] < self.grid_dims[0]
in_width_boundary = destinations[:, 1] < self.grid_dims[1]
destinations = destinations[
in_first_boundary & in_height_boundary & in_width_boundary
]

# check if destination is road
passable_cell_indices = (
Expand Down Expand Up @@ -170,7 +174,8 @@ def step(self, actions: Dict[str, np.ndarray]):
if self.action_mask(agent)[i, j, direction] == 0:
warnings.warn(
f"The submitted move byt agent {agent} does not take effect.\
Probably because you submitted an invalid move.", UserWarning
Probably because you submitted an invalid move.",
UserWarning,
)
continue
if split_army == 1: # Agent wants to split the army
Expand Down
2 changes: 1 addition & 1 deletion generals/integrations/pettingzoo_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import pettingzoo
from copy import deepcopy
from ..game import Game
from ..agents import Agent
from ..replay import Replay
from ..rendering import Renderer
from collections import OrderedDict
from typing import Dict
from ..agent import Agent


class PZ_Generals(pettingzoo.ParallelEnv):
Expand Down
Loading

0 comments on commit d40fcf6

Please sign in to comment.