Skip to content

Commit

Permalink
Merge pull request #76 from strakam/refactors
Browse files Browse the repository at this point in the history
refactor: Include enums, remove unused things
  • Loading branch information
strakam authored Oct 3, 2024
2 parents 9af3228 + c3ef39b commit 136cc61
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 109 deletions.
6 changes: 4 additions & 2 deletions generals/agents/expander_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from .agent import Agent

from generals.core.config import DIRECTIONS
from generals.core.config import Direction


class ExpanderAgent(Agent):
Expand All @@ -27,8 +27,10 @@ def play(self, observation):
# Find actions that capture opponent or neutral cells
actions_capture_opponent = np.zeros(len(valid_actions))
actions_capture_neutral = np.zeros(len(valid_actions))

directions = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]
for i, action in enumerate(valid_actions):
di, dj = action[:-1] + DIRECTIONS[action[-1]] # Destination cell indices
di, dj = action[:-1] + directions[action[-1]].value # Destination cell indices
if army[action[0], action[1]] <= army[di, dj] + 1: # Can't capture
continue
elif opponent[di, dj]:
Expand Down
76 changes: 19 additions & 57 deletions generals/core/config.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,33 @@
from typing import Literal
from importlib.resources import files
from enum import Enum, IntEnum, StrEnum

#################
# Game Literals #
#################
PASSABLE: Literal['.'] = '.'
MOUNTAIN: Literal['#'] = '#'
CITY: Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] = 0 # CITY can be any digit 0-9
PASSABLE: Literal["."] = "."
MOUNTAIN: Literal["#"] = "#"
CITY: Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] = 0 # CITY can be any digit 0-9

#########
# Moves #
#########
UP: tuple[int, int] = (-1, 0)
DOWN: tuple[int, int] = (1, 0)
LEFT: tuple[int, int] = (0, -1)
RIGHT: tuple[int, int] = (0, 1)
DIRECTIONS: list[tuple[int, int]] = [UP, DOWN, LEFT, RIGHT]

##################
# Game constants #
##################
GAME_SPEED: float = 8 # by default, every 8 ticks, actions are processed
class Dimension(IntEnum):
SQUARE_SIZE = 50
GUI_CELL_HEIGHT = 30
GUI_CELL_WIDTH = 70
MINIMUM_WINDOW_SIZE = 700

########################
# Grid visual settings #
########################
SQUARE_SIZE: int = 50
LINE_WIDTH: int = 1
GUI_ROW_HEIGHT: int = 30
GUI_CELL_WIDTH: int = 70
MINIMUM_WINDOW_SIZE: int = 700

##########
# Colors #
##########
FOG_OF_WAR: tuple[int, int, int] = (70, 73, 76)
NEUTRAL_CASTLE: tuple[int, int, int] = (128, 128, 128)
VISIBLE_PATH: tuple[int, int, int] = (200, 200, 200)
VISIBLE_MOUNTAIN: tuple[int, int, int] = (187, 187, 187)
BLACK: tuple[int, int, int] = (0, 0, 0)
WHITE: tuple[int, int, int] = (230, 230, 230)
PLAYER_1_COLOR: tuple[int, int, int] = (255, 0, 0)
PLAYER_2_COLOR: tuple[int, int, int] = (67, 99, 216)
PLAYER_COLORS: list[tuple[int, int, int]] = [PLAYER_1_COLOR, PLAYER_2_COLOR]
class Direction(Enum):
UP = (-1, 0)
DOWN = (1, 0)
LEFT = (0, -1)
RIGHT = (0, 1)

#########
# Fonts #
#########
FONT_TYPE = "Quicksand-Medium.ttf" # Font options are Quicksand-SemiBold.ttf, Quicksand-Medium.ttf, Quicksand-Light.ttf
FONT_SIZE = 18
try:
file_ref = files("generals.assets.fonts") / FONT_TYPE
FONT_PATH = str(file_ref)
except FileNotFoundError:
raise FileNotFoundError(f"Font file {FONT_TYPE} not found in the fonts directory")

#########
# Icons #
#########
try:
class Path(StrEnum):
GENERAL_PATH = str(files("generals.assets.images") / "crownie.png")
except FileNotFoundError:
raise FileNotFoundError("Image not found")
try:
CITY_PATH = str(files("generals.assets.images") / "citie.png")
except FileNotFoundError:
raise FileNotFoundError("Image not found")
try:
MOUNTAIN_PATH = str(files("generals.assets.images") / "mountainie.png")
except FileNotFoundError:
raise FileNotFoundError("Image not found")

# Font options are Quicksand-SemiBold.ttf, Quicksand-Medium.ttf, Quicksand-Light.ttf
FONT_PATH = str(files("generals.assets.fonts") / "Quicksand-Medium.ttf")
11 changes: 6 additions & 5 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .channels import Channels
from .grid import Grid
from .config import DIRECTIONS
from .config import Direction

from scipy.ndimage import maximum_filter

Expand All @@ -16,6 +16,7 @@
Info: TypeAlias = dict[str, Any]

increment_rate = 50
DIRECTIONS = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]


class Game:
Expand Down Expand Up @@ -99,7 +100,7 @@ def action_mask(self, agent: str) -> np.ndarray:
return valid_action_mask

for channel_index, direction in enumerate(DIRECTIONS):
destinations = owned_cells_indices + direction
destinations = owned_cells_indices + direction.value

# check if destination is in grid bounds
in_first_boundary = np.all(destinations >= 0, axis=1)
Expand All @@ -116,7 +117,7 @@ def action_mask(self, agent: str) -> np.ndarray:
action_destinations = destinations[passable_cell_indices]

# get valid action mask for a given direction
valid_source_indices = action_destinations - direction
valid_source_indices = action_destinations - direction.value
valid_action_mask[
valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index
] = 1.0
Expand Down Expand Up @@ -187,8 +188,8 @@ def step(
continue

di, dj = (
si + DIRECTIONS[direction][0],
sj + DIRECTIONS[direction][1],
si + DIRECTIONS[direction].value[0],
sj + DIRECTIONS[direction].value[1],
) # destination indices

# Figure out the target square owner and army size
Expand Down
17 changes: 11 additions & 6 deletions generals/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

from generals.core.game import Game
from .properties import Properties
from .event_handler import TrainEventHandler, GameEventHandler, ReplayEventHandler, ReplayCommand
from .event_handler import (
TrainEventHandler,
GameEventHandler,
ReplayEventHandler,
Command,
)
from .rendering import Renderer


Expand All @@ -14,16 +19,16 @@ def __init__(
agent_data: dict[str, dict[str, Any]],
mode: Literal["train", "game", "replay"] = "train",
):
self.properties = Properties(game, agent_data, mode)
self.__renderer = Renderer(self.properties)
self.__event_handler = self.__initialize_event_handler()

pygame.init()
pygame.display.set_caption("Generals")

# Handle key repeats
pygame.key.set_repeat(500, 64)

self.properties = Properties(game, agent_data, mode)
self.__renderer = Renderer(self.properties)
self.__event_handler = self.__initialize_event_handler()

def __initialize_event_handler(self):
if self.properties.mode == "train":
return TrainEventHandler
Expand All @@ -32,7 +37,7 @@ def __initialize_event_handler(self):
elif self.properties.mode == "replay":
return ReplayEventHandler

def tick(self, fps=None):
def tick(self, fps=None) -> Command:
handler = self.__event_handler(self.properties)
command = handler.handle_events()
if command.quit:
Expand Down
15 changes: 11 additions & 4 deletions generals/gui/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from pygame.time import Clock

from generals.core import config as c
from generals.core.game import Game
from generals.core.config import Dimension


@dataclass
Expand All @@ -14,13 +14,16 @@ class Properties:
__mode: Literal["train", "game", "replay"]
__game_speed: int = 1
__clock: Clock = Clock()
__font_size = 18

def __post_init__(self):
self.__grid_height: int = self.__game.grid_dims[0]
self.__grid_width: int = self.__game.grid_dims[1]
self.__display_grid_width: int = c.SQUARE_SIZE * self.grid_width
self.__display_grid_height: int = c.SQUARE_SIZE * self.grid_height
self.__right_panel_width: int = 4 * c.GUI_CELL_WIDTH
self.__display_grid_width: int = Dimension.SQUARE_SIZE.value * self.grid_width
self.__display_grid_height: int = (
Dimension.SQUARE_SIZE.value * self.grid_height
)
self.__right_panel_width: int = 4 * Dimension.GUI_CELL_WIDTH.value

self.__paused: bool = False

Expand Down Expand Up @@ -85,6 +88,10 @@ def display_grid_height(self):
def right_panel_width(self):
return self.__right_panel_width

@property
def font_size(self):
return self.__font_size

def update_speed(self, multiplier: float) -> None:
"""multiplier: usually 2.0 or 0.5"""
new_speed = self.game_speed * multiplier
Expand Down
Loading

0 comments on commit 136cc61

Please sign in to comment.