refactor: Make Observation a standalone object
...and make it compatible with observations from the simulator and generalsio
strakam committed Oct 23, 2024
1 parent 64a10ef commit b58e4fc
Showing 7 changed files with 329 additions and 117 deletions.
generals/agents/expander_agent.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ def act(self, observation: Observation) -> Action:
"split": 0,
}

army = observation["army"]
army = observation["armies"]
opponent = observation["opponent_cells"]
neutral = observation["neutral_cells"]

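The only change in this file is the observation key rename ("army" becomes "armies"). Below is a minimal sketch of how these three channels might be read and combined after the rename; the arrays and the targeting heuristic are invented for illustration and are not the repo's ExpanderAgent logic.

import numpy as np

# Hypothetical observation fragment; key names follow the diff above.
observation = {
    "armies": np.array([[3, 1], [2, 0]]),
    "opponent_cells": np.array([[0, 1], [0, 0]], dtype=bool),
    "neutral_cells": np.array([[0, 0], [1, 1]], dtype=bool),
}

army = observation["armies"]  # renamed from "army" in this commit
opponent = observation["opponent_cells"]
neutral = observation["neutral_cells"]

# Cells an expander-style agent could target: anything not already owned.
targets = np.argwhere(opponent | neutral)
print(targets)  # three (row, col) pairs: (0, 1), (1, 0), (1, 1)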
generals/core/channels.py (7 changes: 7 additions & 0 deletions)
@@ -39,6 +39,13 @@ def get_visibility(self, agent_id: str) -> np.ndarray:
channel = self._ownership[agent_id]
return maximum_filter(channel, size=3)

+ @staticmethod
+ def channel_to_indices(channel: np.ndarray) -> np.ndarray:
+     """
+     Returns a list of indices of cells with non-zero values from a specified channel.
+     """
+     return np.argwhere(channel != 0)

@property
def ownership(self) -> dict[str, np.ndarray]:
return self._ownership
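channel_to_indices now lives on the channels object as a staticmethod (it was previously a Game method, removed further down in game.py). A small, self-contained illustration of what it returns; this snippet is not part of the diff, it just exercises np.argwhere the same way the added method does.

import numpy as np

channel = np.array([
    [0, 1, 0],
    [0, 0, 2],
])

indices = np.argwhere(channel != 0)
print(indices)
# [[0 1]
#  [1 2]]  -> (row, col) of every non-zero cell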
generals/core/game.py (170 changes: 73 additions & 97 deletions)
@@ -40,9 +40,9 @@ def __init__(self, grid: Grid, agents: list[str]):
{
"observation": gym.spaces.Dict(
{
"army": gym.spaces.MultiDiscrete(grid_discrete),
"general": grid_multi_binary,
"city": grid_multi_binary,
"armies": gym.spaces.MultiDiscrete(grid_discrete),
"generals": grid_multi_binary,
"cities": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
@@ -52,7 +52,6 @@ def __init__(self, grid: Grid, agents: list[str]):
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_army_value),
"opponent_army_count": gym.spaces.Discrete(self.max_army_value),
"is_winner": gym.spaces.Discrete(2),
"timestep": gym.spaces.Discrete(self.max_timestep),
}
),
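The per-cell spaces keep their types; only the keys change ("army"/"general"/"city" become "armies"/"generals"/"cities") and "is_winner" is dropped from the observation space (it remains available through get_infos(), shown further down). A rough sketch of how such spaces are constructed; the grid size and maximum army value are placeholders, and the import is an assumption since the project may use either gym or gymnasium.

import numpy as np
import gymnasium as gym  # assumption: the repo may import plain `gym` instead

H, W = 4, 4               # placeholder grid size
max_army_value = 100      # placeholder cap

grid_multi_binary = gym.spaces.MultiBinary((H, W))
grid_discrete = np.full((H, W), max_army_value, dtype=int)

armies = gym.spaces.MultiDiscrete(grid_discrete)  # per-cell army counts in [0, max_army_value)
generals = grid_multi_binary                      # per-cell 0/1 flags
cities = grid_multi_binary

print(armies.sample().shape)  # (4, 4)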
@@ -69,51 +68,46 @@
}
)

- def action_mask(self, agent: str) -> np.ndarray:
- """
- Function to compute valid actions from a given ownership mask.
- Valid action is an action that originates from agent's cell with atleast 2 units
- and does not bump into a mountain or fall out of the grid.
- Returns:
- np.ndarray: an NxNx4 array, where each channel is a boolean mask
- of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid.
- I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j).
- """
- 
- ownership_channel = self.channels.ownership[agent]
- more_than_1_army = (self.channels.army > 1) * ownership_channel
- owned_cells_indices = self.channel_to_indices(more_than_1_army)
- valid_action_mask = np.zeros((self.grid_dims[0], self.grid_dims[1], 4), dtype=bool)
- 
- if self.is_done() and not self.agent_won(agent): # if you lost, return all zeros
- return valid_action_mask
+ # def action_mask(self, agent: str) -> np.ndarray:
+ # """
+ # Function to compute valid actions from a given ownership mask.
+ #
+ # Valid action is an action that originates from agent's cell with atleast 2 units
+ # and does not bump into a mountain or fall out of the grid.
+ # Returns:
+ # np.ndarray: an NxNx4 array, where each channel is a boolean mask
+ # of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid.
+ #
+ # I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j).
+ # """
+ #
+ # ownership_channel = self.channels.ownership[agent]
+ # more_than_1_army = (self.channels.army > 1) * ownership_channel
+ # owned_cells_indices = self.channel_to_indices(more_than_1_army)
+ # valid_action_mask = np.zeros((self.grid_dims[0], self.grid_dims[1], 4), dtype=bool)
+ #
+ # if self.is_done() and not self.agent_won(agent): # if you lost, return all zeros
+ # return valid_action_mask
+ #
+ # for channel_index, direction in enumerate(DIRECTIONS):
+ # destinations = owned_cells_indices + direction.value
+ #
+ # # check if destination is in grid bounds
+ # in_first_boundary = np.all(destinations >= 0, axis=1)
+ # in_height_boundary = destinations[:, 0] < self.grid_dims[0]
+ # in_width_boundary = destinations[:, 1] < self.grid_dims[1]
+ # destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary]
+ #
+ # # check if destination is road
+ # passable_cell_indices = self.channels.passable[destinations[:, 0], destinations[:, 1]] == 1
+ # action_destinations = destinations[passable_cell_indices]
+ #
+ # # get valid action mask for a given direction
+ # valid_source_indices = action_destinations - direction.value
+ # valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0
+ # # assert False
+ # return valid_action_mask

- for channel_index, direction in enumerate(DIRECTIONS):
- destinations = owned_cells_indices + direction.value
- 
- # check if destination is in grid bounds
- in_first_boundary = np.all(destinations >= 0, axis=1)
- in_height_boundary = destinations[:, 0] < self.grid_dims[0]
- in_width_boundary = destinations[:, 1] < self.grid_dims[1]
- destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary]
- 
- # check if destination is road
- passable_cell_indices = self.channels.passable[destinations[:, 0], destinations[:, 1]] == 1
- action_destinations = destinations[passable_cell_indices]
- 
- # get valid action mask for a given direction
- valid_source_indices = action_destinations - direction.value
- valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0
- # assert False
- return valid_action_mask

- def channel_to_indices(self, channel: np.ndarray) -> np.ndarray:
- """
- Returns a list of indices of cells with non-zero values from specified a channel.
- """
- return np.argwhere(channel != 0)
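The commented-out action_mask above builds an H x W x 4 boolean array by, for each direction, shifting the indices of owned cells holding more than one army, discarding destinations that leave the grid or land on impassable cells, and shifting the survivors back to mark valid source cells. A standalone sketch of that filtering for a single direction follows; it is pure numpy with a made-up grid and passability channel, not code from the repo.

import numpy as np

grid_dims = (3, 3)
# (row, col) of cells we own with more than one army -- made-up example data.
owned_cells_indices = np.array([[0, 0], [0, 1], [2, 2]])
# 1 = passable, 0 = mountain.
passable = np.array([
    [1, 1, 1],
    [1, 0, 1],
    [1, 1, 1],
])
direction = np.array([1, 0])  # "DOWN": one row further down

destinations = owned_cells_indices + direction

# Keep destinations that stay inside the grid.
in_bounds = (
    np.all(destinations >= 0, axis=1)
    & (destinations[:, 0] < grid_dims[0])
    & (destinations[:, 1] < grid_dims[1])
)
destinations = destinations[in_bounds]

# Keep destinations that are passable (not mountains).
destinations = destinations[passable[destinations[:, 0], destinations[:, 1]] == 1]

# Valid source cells for this direction are the surviving destinations shifted back.
valid_sources = destinations - direction
print(valid_sources)  # [[0 0]]: DOWN from (0, 0) is valid; (0, 1) hits a mountain, (2, 2) leaves the grid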

def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, Any]]:
"""
@@ -135,14 +129,6 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict
# Skip if agent wants to pass the turn
if pass_turn == 1:
continue
- # Skip if the move is invalid
- if self.action_mask(agent)[i, j, direction] == 0:
- warnings.warn(
- f"The submitted move byt agent {agent} does not take effect.\
- Probably because you submitted an invalid move.",
- UserWarning,
- )
- continue
if split_army == 1: # Agent wants to split the army
army_to_move = self.channels.army[i, j] // 2
else: # Leave just one army in the source cell
@@ -199,16 +185,6 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict
else:
self._global_game_update()

- observations = self.get_all_observations()
- infos: dict[str, Any] = {agent: {} for agent in self.agents}
- return observations, infos
- 
- def get_all_observations(self) -> dict[str, Observation]:
- """
- Returns observations for all agents.
- """
- return {agent: self.agent_observation(agent) for agent in self.agents}

def _global_game_update(self) -> None:
"""
Update game state globally.
@@ -250,37 +226,37 @@ def get_infos(self) -> dict[str, Info]:
"is_winner": self.agent_won(agent),
}
return players_stats

- def agent_observation(self, agent: str) -> Observation:
- """
- Returns an observation for a given agent.
- """
- info = self.get_infos()
- opponent = self.agents[0] if agent == self.agents[1] else self.agents[1]
- visible = self.channels.get_visibility(agent)
- invisible = 1 - visible
- _observation = {
- "army": self.channels.army.astype(int) * visible,
- "general": self.channels.general * visible,
- "city": self.channels.city * visible,
- "owned_cells": self.channels.ownership[agent] * visible,
- "opponent_cells": self.channels.ownership[opponent] * visible,
- "neutral_cells": self.channels.ownership_neutral * visible,
- "visible_cells": visible,
- "structures_in_fog": invisible * (self.channels.mountain + self.channels.city),
- "owned_land_count": info[agent]["land"],
- "owned_army_count": info[agent]["army"],
- "opponent_land_count": info[opponent]["land"],
- "opponent_army_count": info[opponent]["army"],
- "is_winner": int(info[agent]["is_winner"]),
- "timestep": self.time,
- }
- observation: Observation = {
- "observation": _observation,
- "action_mask": self.action_mask(agent),
- }
- 
- return observation
+ #
+ # def agent_observation(self, agent: str) -> Observation:
+ # """
+ # Returns an observation for a given agent.
+ # """
+ # info = self.get_infos()
+ # opponent = self.agents[0] if agent == self.agents[1] else self.agents[1]
+ # visible = self.channels.get_visibility(agent)
+ # invisible = 1 - visible
+ # _observation = {
+ # "army": self.channels.army.astype(int) * visible,
+ # "general": self.channels.general * visible,
+ # "city": self.channels.city * visible,
+ # "owned_cells": self.channels.ownership[agent] * visible,
+ # "opponent_cells": self.channels.ownership[opponent] * visible,
+ # "neutral_cells": self.channels.ownership_neutral * visible,
+ # "visible_cells": visible,
+ # "structures_in_fog": invisible * (self.channels.mountain + self.channels.city),
+ # "owned_land_count": info[agent]["land"],
+ # "owned_army_count": info[agent]["army"],
+ # "opponent_land_count": info[opponent]["land"],
+ # "opponent_army_count": info[opponent]["army"],
+ # "is_winner": int(info[agent]["is_winner"]),
+ # "timestep": self.time,
+ # }
+ # observation: Observation = {
+ # "observation": _observation,
+ # "action_mask": self.action_mask(agent),
+ # }
+ #
+ # return observation

def agent_won(self, agent: str) -> bool:
"""
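The standalone Observation object that replaces the dict built by the old agent_observation() lives in changed files whose diffs are not shown above, so its real shape may differ from this sketch. Purely as an illustration of the idea, one object that both the simulator and the generalsio client can populate, here is a hypothetical version whose field names simply mirror the keys visible in the old code and in the renamed observation space.

from dataclasses import dataclass

import numpy as np


@dataclass
class Observation:
    # Hypothetical sketch, not the repo's actual class.
    armies: np.ndarray
    generals: np.ndarray
    cities: np.ndarray
    owned_cells: np.ndarray
    opponent_cells: np.ndarray
    neutral_cells: np.ndarray
    visible_cells: np.ndarray
    structures_in_fog: np.ndarray
    owned_land_count: int
    owned_army_count: int
    opponent_land_count: int
    opponent_army_count: int
    timestep: int

    def as_dict(self) -> dict:
        # One plausible bridge back to the dict form the agents indexed before.
        return self.__dict__.copy()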