Skip to content

Commit

Permalink
Merge pull request #83 from strakam/abstract-envs
Browse files Browse the repository at this point in the history
Wrappers and structure
  • Loading branch information
strakam authored Oct 12, 2024
2 parents 4eeb059 + acf8fd5 commit 620fac9
Show file tree
Hide file tree
Showing 17 changed files with 284 additions and 143 deletions.
3 changes: 2 additions & 1 deletion examples/complete_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
env = gym.make(
"gym-generals-v0", # Environment name
grid_factory=grid_factory, # Grid factory
agent=agent, # Your agent (used to get metadata like name and color)
npc=npc, # NPC that will play against the agent
agent_id="Agent", # Agent ID
agent_color=(67, 70, 86), # Agent color
render_mode="human", # "human" mode is for rendering, None is for no rendering
)

Expand Down
12 changes: 3 additions & 9 deletions examples/gymnasium_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,14 @@
from generals import AgentFactory

# Initialize agents
agent = AgentFactory.make_agent("expander")
npc = AgentFactory.make_agent("random")

env = gym.make(
"gym-generals-v0",
agent=agent,
npc=npc,
render_mode="human",
)
# Create environment
env = gym.make("gym-generals-v0", npc=npc, render_mode="human")

observation, info = env.reset()

terminated = truncated = False
while not (terminated or truncated):
action = agent.act(observation)
action = env.action_space.sample() # Here you put your agent's action
observation, reward, terminated, truncated, info = env.step(action)
env.render()
9 changes: 4 additions & 5 deletions examples/pettingzoo_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
expander = AgentFactory.make_agent("expander")

agents = {
random.name: random,
expander.name: expander,
random.id: random,
expander.id: expander,
} # Environment calls agents by name

# Create environment -- render modes: {None, "human"}
# env = pz_generals(agents=agents, render_mode="human")D
env = gym.make("pz-generals-v0", agents=agents, render_mode="human")
env = gym.make("pz-generals-v0", agents=list(agents.keys()), render_mode="human")
observations, info = env.reset()

done = False
Expand All @@ -25,4 +24,4 @@
# All agents perform their actions
observations, rewards, terminated, truncated, info = env.step(actions)
done = any(terminated.values()) or any(truncated.values())
env.render(fps=6)
env.render()
3 changes: 2 additions & 1 deletion examples/record_replay_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
env = gym.make(
"gym-generals-v0", # Environment name
grid_factory=grid_factory, # Grid factory
agent=agent, # Your agent (used to get metadata like name and color)
agent_id="Agent", # Agent ID
agent_color=(67, 70, 86), # Agent color
npc=npc, # NPC that will play against the agent
)

Expand Down
14 changes: 3 additions & 11 deletions generals/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ class Agent(ABC):
Base class for all agents.
"""

def __init__(self, name="Agent", color=(67, 70, 86)):
self.name = name
def __init__(self, id="NPC", color=(67, 70, 86)):
self.id = id
self.color = color

@abstractmethod
Expand All @@ -27,12 +27,4 @@ def reset(self):
raise NotImplementedError

def __str__(self):
return self.name


class EmptyAgent(Agent):
def act(self, observation):
return None

def reset(self):
pass
return self.id
5 changes: 2 additions & 3 deletions generals/agents/agent_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .random_agent import RandomAgent
from .expander_agent import ExpanderAgent
from .agent import Agent, EmptyAgent
from .agent import Agent


class AgentFactory:
"""
Expand All @@ -19,7 +20,5 @@ def make_agent(agent_type: str, **kwargs) -> Agent:
return RandomAgent(**kwargs)
elif agent_type == "expander":
return ExpanderAgent(**kwargs)
elif agent_type == "empty":
return EmptyAgent(**kwargs)
else:
raise ValueError(f"Unknown agent type: {agent_type}")
8 changes: 5 additions & 3 deletions generals/agents/expander_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


class ExpanderAgent(Agent):
def __init__(self, name="Expander", color=(0, 130, 255)):
super().__init__(name, color)
def __init__(self, id="Expander", color=(0, 130, 255)):
super().__init__(id, color)

def act(self, observation):
"""
Expand All @@ -30,7 +30,9 @@ def act(self, observation):

directions = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]
for i, action in enumerate(valid_actions):
di, dj = action[:-1] + directions[action[-1]].value # Destination cell indices
di, dj = (
action[:-1] + directions[action[-1]].value
) # Destination cell indices
if army[action[0], action[1]] <= army[di, dj] + 1: # Can't capture
continue
elif opponent[di, dj]:
Expand Down
4 changes: 2 additions & 2 deletions generals/agents/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

class RandomAgent(Agent):
def __init__(
self, name="Random", color=(242, 61, 106), split_prob=0.25, idle_prob=0.05
self, id="Random", color=(242, 61, 106), split_prob=0.25, idle_prob=0.05
):
super().__init__(name, color)
super().__init__(id, color)

self.idle_probability = idle_prob
self.split_probability = split_prob
Expand Down
26 changes: 15 additions & 11 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Action: TypeAlias = gym.Space
Info: TypeAlias = dict[str, Any]

increment_rate = 50

DIRECTIONS = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]


Expand All @@ -33,12 +33,16 @@ def __init__(self, grid: Grid, agents: list[str]):

self.channels = Channels(map, self.agents)

# Constants
self.increment_rate = 50
self.max_army_value = 10_000
self.max_land_value = np.prod(self.grid_dims)
self.max_timestep = 100_000
##########
# Spaces #
##########
max_value = 100_000
grid_multi_binary = gym.spaces.MultiBinary(self.grid_dims)
grid_discrete = np.ones(self.grid_dims, dtype=int) * max_value
grid_discrete = np.ones(self.grid_dims, dtype=int) * self.max_army_value
self.observation_space = gym.spaces.Dict(
{
"observation": gym.spaces.Dict(
Expand All @@ -51,12 +55,12 @@ def __init__(self, grid: Grid, agents: list[str]):
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"structure": grid_multi_binary,
"owned_land_count": gym.spaces.Discrete(max_value),
"owned_army_count": gym.spaces.Discrete(max_value),
"opponent_land_count": gym.spaces.Discrete(max_value),
"opponent_army_count": gym.spaces.Discrete(max_value),
"owned_land_count": gym.spaces.Discrete(self.max_army_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_army_value),
"opponent_army_count": gym.spaces.Discrete(self.max_army_value),
"is_winner": gym.spaces.Discrete(2),
"timestep": gym.spaces.Discrete(max_value),
"timestep": gym.spaces.Discrete(self.max_timestep),
}
),
"action_mask": gym.spaces.MultiBinary(self.grid_dims + (4,)),
Expand Down Expand Up @@ -238,7 +242,7 @@ def get_all_observations(self) -> dict[str, Observation]:
"""
Returns observations for all agents.
"""
return {agent: self._agent_observation(agent) for agent in self.agents}
return {agent: self.agent_observation(agent) for agent in self.agents}

def _global_game_update(self) -> None:
"""
Expand All @@ -248,7 +252,7 @@ def _global_game_update(self) -> None:
owners = self.agents

# every `increment_rate` steps, increase army size in each cell
if self.time % increment_rate == 0:
if self.time % self.increment_rate == 0:
for owner in owners:
self.channels.army += self.channels.ownership[owner]

Expand Down Expand Up @@ -284,7 +288,7 @@ def get_infos(self) -> dict[str, Info]:
}
return players_stats

def _agent_observation(self, agent: str) -> Observation:
def agent_observation(self, agent: str) -> Observation:
"""
Returns an observation for a given agent.
Args:
Expand Down
33 changes: 17 additions & 16 deletions generals/envs/env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .gymnasium_integration import Gym_Generals, RewardFn
from .pettingzoo_integration import PZ_Generals
from .gymnasium_generals import GymnasiumGenerals
from .pettingzoo_generals import PettingZooGenerals
from generals.agents import Agent, AgentFactory

from generals import GridFactory
Expand All @@ -17,35 +17,36 @@

def pz_generals_v0(
grid_factory: GridFactory = GridFactory(),
agents: dict[str, Agent] = None,
reward_fn: RewardFn = None,
agents: list[str] = None,
render_mode=None,
):
assert len(agents) == 2, "Only 2 agents are supported in PZ_Generals."
env = PZ_Generals(
assert len(agents) == 2, "For now, only 2 agents are supported in PZ_Generals."
env = PettingZooGenerals(
grid_factory=grid_factory,
agents=agents,
reward_fn=reward_fn,
render_mode=render_mode,
)
return env


def gym_generals_v0(
grid_factory: GridFactory = GridFactory(),
agent: Agent = None,
npc: Agent = None,
reward_fn: RewardFn = None,
render_mode=None,
reward_fn=None,
agent_id: str = "Agent",
agent_color: tuple[int, int, int] = (67, 70, 86),
):
if npc is None:
print("No npc provided, using RandomAgent.")
npc = AgentFactory.init_agent("random")
env = Gym_Generals(
if not isinstance(npc, Agent):
print(
"NPC must be an instance of Agent class, Creating random NPC as a fallback."
)
npc = AgentFactory.make_agent("random")
env = GymnasiumGenerals(
grid_factory=grid_factory,
agent=agent,
npc=npc,
reward_fn=reward_fn,
render_mode=render_mode,
agent_id=agent_id,
agent_color=agent_color,
reward_fn=reward_fn,
)
return env
Loading

0 comments on commit 620fac9

Please sign in to comment.