Skip to content

Commit

Permalink
Merge pull request #63 from Puckoland/gymReplayFix
Browse files Browse the repository at this point in the history
Gym replay fix
  • Loading branch information
strakam authored Oct 1, 2024
2 parents d8999e1 + 71e72ab commit 426b6ef
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 24 deletions.
3 changes: 2 additions & 1 deletion generals/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from scipy.ndimage import maximum_filter

Observation: TypeAlias = dict[str, gym.Space | dict[str, gym.Space]]
Action: TypeAlias = gym.Space
Info: TypeAlias = dict[str, Any]


Expand Down Expand Up @@ -166,7 +167,7 @@ def visibility_channel(self, ownership_channel: np.ndarray) -> np.ndarray:
"""
return maximum_filter(ownership_channel, size=3)

def step(self, actions: dict[str, gym.spaces.Tuple]) -> dict[str, gym.spaces.Dict]:
def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, dict]]:
"""
Perform one step of the game
Expand Down
23 changes: 13 additions & 10 deletions generals/integrations/gymnasium_integration.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from collections.abc import Callable
from typing import TypeAlias
from typing import TypeAlias, Any, SupportsFloat

import gymnasium as gym
import functools
from copy import deepcopy

from ..agents import Agent
from ..game import Game, Observation
from ..game import Game, Action, Observation
from ..grid import GridFactory
from ..gui import GUI
from ..replay import Replay

# Type aliases
from generals.game import Info
Reward: TypeAlias = float
RewardFn: TypeAlias = Callable[[dict[str, Observation]], Reward]
RewardFn: TypeAlias = Callable[[dict[str, Observation], Action, bool, Info], Reward]


class Gym_Generals(gym.Env):
Expand Down Expand Up @@ -45,18 +46,20 @@ def __init__(
self.action_space = game.action_space

@functools.lru_cache(maxsize=None)
def observation_space(self):
def observation_space(self) -> gym.Space:
return self.game.observation_space

@functools.lru_cache(maxsize=None)
def action_space(self):
def action_space(self) -> gym.Space:
return self.game.action_space

def render(self, fps: int = 6) -> None:
if self.render_mode == "human":
self.gui.tick(fps=fps)

def reset(self, seed=None, options=None):
def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[
Observation, dict[str, Any]
]:
if options is None:
options = {}
super().reset(seed=seed)
Expand All @@ -79,7 +82,7 @@ def reset(self, seed=None, options=None):
self.replay = Replay(
name=options["replay_file"],
grid=grid,
agent_data=self.agent_colors,
agent_data=self.agent_data,
)
self.replay.add_state(deepcopy(self.game.channels))
elif hasattr(self, "replay"):
Expand All @@ -89,7 +92,7 @@ def reset(self, seed=None, options=None):
info = {}
return observation, info

def step(self, action):
def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
# get action of NPC
npc_action = self.npc.play(self.game._agent_observation(self.npc.name))
actions = {self.agent_name: action, self.npc.name: npc_action}
Expand All @@ -107,13 +110,13 @@ def step(self, action):

if terminated:
if hasattr(self, "replay"):
self.replay.save()
self.replay.store()

return observation, reward, terminated, truncated, info

def default_reward(
self, observation: dict[str, Observation],
action: gym.Space,
action: Action,
done: bool,
info: Info,
) -> Reward:
Expand Down
33 changes: 21 additions & 12 deletions generals/integrations/pettingzoo_integration.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import functools
from collections.abc import Callable
from typing import TypeAlias
from typing import TypeAlias, Any

import pettingzoo
from gymnasium import spaces
from copy import deepcopy
from ..game import Game, Observation

from pettingzoo.utils.env import AgentID

from ..game import Game, Action, Observation
from ..grid import GridFactory
from ..agents import Agent
from ..gui import GUI
Expand All @@ -15,8 +18,8 @@
# Type aliases
from generals.game import Info

Reward: TypeAlias = dict[str, float]
RewardFn: TypeAlias = Callable[[dict[str, Observation]], Reward]
Reward: TypeAlias = float
RewardFn: TypeAlias = Callable[[dict[str, Observation], Action, bool, Info], Reward]


class PZ_Generals(pettingzoo.ParallelEnv):
Expand All @@ -43,20 +46,24 @@ def __init__(
self.reward_fn = self.default_reward if reward_fn is None else reward_fn

@functools.lru_cache(maxsize=None)
def observation_space(self, agent):
def observation_space(self, agent: AgentID) -> spaces.Space:
assert agent in self.possible_agents, f"Agent {agent} not in possible agents"
return self.game.observation_space

@functools.lru_cache(maxsize=None)
def action_space(self, agent):
def action_space(self, agent: AgentID) -> spaces.Space:
assert agent in self.possible_agents, f"Agent {agent} not in possible agents"
return self.game.action_space

def render(self, fps=6):
def render(self, fps=6) -> None:
if self.render_mode == "human":
self.gui.tick(fps=fps)

def reset(self, seed=None, options={}):
def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[
dict[AgentID, Observation], dict[AgentID, dict]
]:
if options is None:
options = {}
self.agents = deepcopy(self.possible_agents)

if "grid" in options:
Expand All @@ -83,8 +90,10 @@ def reset(self, seed=None, options={}):
infos = {agent: {} for agent in self.agents}
return observations, infos

def step(self, action):
observations, infos = self.game.step(action)
def step(self, actions: dict[AgentID, Action]) -> tuple[
dict[AgentID, Observation], dict[AgentID, float], dict[AgentID, bool], dict[AgentID, bool], dict[AgentID, Info]
]:
observations, infos = self.game.step(actions)

truncated = {agent: False for agent in self.agents} # no truncation
terminated = {
Expand All @@ -94,7 +103,7 @@ def step(self, action):
rewards = {
agent: self.reward_fn(
observations[agent],
action,
actions[agent],
terminated[agent] or truncated[agent],
infos[agent],
)
Expand All @@ -116,7 +125,7 @@ def step(self, action):
def default_reward(
self,
observation: dict[str, Observation],
action: spaces.Tuple,
action: Action,
done: bool,
info: Info,
) -> Reward:
Expand Down
1 change: 0 additions & 1 deletion generals/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import time

from generals.grid import Grid
from generals.rendering import Renderer
from generals.gui import GUI
from generals.game import Game
from copy import deepcopy
Expand Down

0 comments on commit 426b6ef

Please sign in to comment.