Skip to content

Commit

Permalink
Merge pull request #28 from strakam/heuristic_agent
Browse files Browse the repository at this point in the history
Heuristic agent
  • Loading branch information
strakam authored Sep 23, 2024
2 parents 1ea0e30 + 553b243 commit 0309922
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 13 deletions.
10 changes: 5 additions & 5 deletions examples/pettingzoo_example.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from generals.env import pz_generals
from generals.agents import RandomAgent
from generals.agents import ExpanderAgent, RandomAgent
from generals.config import GameConfig

# Initialize agents - their names are then called for actions
agents = {
"Red": RandomAgent("Red"),
"Blue": RandomAgent("Blue")
"Random": RandomAgent("Random"),
"Expander": ExpanderAgent("Expander")
}

game_config = GameConfig(
grid_size=16,
mountain_density=0.2,
city_density=0.05,
general_positions=[(2, 12), (8, 9)],
general_positions=[(4, 12), (12, 4)],
agent_names=list(agents.keys()),
)

Expand All @@ -21,7 +21,7 @@
observations, info = env.reset(options={"replay_file": "test"})

# How fast we want rendering to be
actions_per_second = 2
actions_per_second = 6

while not env.game.is_done():
actions = {}
Expand Down
56 changes: 56 additions & 0 deletions generals/agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from generals.config import DIRECTIONS

class Agent:
"""
Base class for all agents.
Expand Down Expand Up @@ -32,3 +34,57 @@ def play(self, observation):
# append 1 or 0 randomly to the action (to say whether to send half of troops or all troops)
action = np.append(valid_actions[action_index], np.random.choice([0, 1]))
return action

class ExpanderAgent(Agent):
def __init__(self, name):
super().__init__(name)

def play(self, observation):
"""
Heuristically selects a valid (expanding) action.
Prioritizes capturing opponent and then neutral cells.
"""
mask = observation["action_mask"]
army = observation["army"]

valid_actions = np.argwhere(mask == 1)
actions_with_more_than_1_army = (
army[valid_actions[:, 0], valid_actions[:, 1]] > 1
)
if np.sum(actions_with_more_than_1_army) == 0:
return [-1, -1, 0, 0] # IDLE move

valid_actions = valid_actions[actions_with_more_than_1_army]

opponent = observation["ownership_opponent"]
neutral = observation["ownership_neutral"]

# find actions that capture opponent or neutral cells
actions_to_opponent = np.zeros(len(valid_actions))
actions_to_neutral = np.zeros(len(valid_actions))
for i, action in enumerate(valid_actions):
destination = action[:-1] + DIRECTIONS[action[-1]]
if army[action[0], action[1]] <= army[destination[0], destination[1]] + 1:
continue
elif opponent[destination[0], destination[1]]:
actions_to_opponent[i] = 1
if neutral[destination[0], destination[1]]:
actions_to_neutral[i] = 1

actions_to_neutral_indices = np.argwhere(actions_to_neutral == 1).flatten()
actions_to_opponent_indices = np.argwhere(actions_to_opponent == 1).flatten()
if len(actions_to_opponent_indices) > 0:
# pick random action that captures an opponent cell
action_index = np.random.choice(len(actions_to_opponent_indices))
action = valid_actions[actions_to_opponent_indices[action_index]]
elif len(actions_to_neutral_indices) > 0:
# or pick random action that captures a neutral cell
action_index = np.random.choice(len(actions_to_neutral_indices))
action = valid_actions[actions_to_neutral_indices[action_index]]
else: # otherwise pick a random action
action_index = np.random.choice(len(valid_actions))
action = valid_actions[action_index]

# append 0 to the action (to send all available troops)
action = np.append(action, 0)
return action
1 change: 1 addition & 0 deletions generals/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class GameConfig(BaseModel):
DOWN: List[int] = [1, 0]
LEFT: List[int] = [0, -1]
RIGHT: List[int] = [0, 1]
DIRECTIONS: List[List[int]] = [UP, DOWN, LEFT, RIGHT]

##################
# Game constants #
Expand Down
11 changes: 4 additions & 7 deletions generals/game.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import numpy as np
import gymnasium as gym
from typing import Dict, List
from generals.config import PASSABLE, MOUNTAIN
from generals.config import UP, DOWN, LEFT, RIGHT
from generals.config import INCREMENT_RATE
from generals.config import DIRECTIONS, PASSABLE, MOUNTAIN, INCREMENT_RATE

from scipy.ndimage import maximum_filter

Expand Down Expand Up @@ -109,7 +107,7 @@ def action_mask(self, agent: str) -> np.ndarray:
if self.is_done():
return valid_action_mask

for channel_index, direction in enumerate([UP, DOWN, LEFT, RIGHT]):
for channel_index, direction in enumerate(DIRECTIONS):
destinations = owned_cells_indices + direction

# check if destination is in grid bounds
Expand Down Expand Up @@ -158,7 +156,6 @@ def step(self, actions: Dict[str, np.ndarray]):
actions: dictionary of agent name to action
"""
done_before_actions = self.is_done()
directions = np.array([UP, DOWN, LEFT, RIGHT])

# Agent with smaller army to move is prioritized
armies = [
Expand All @@ -178,8 +175,8 @@ def step(self, actions: Dict[str, np.ndarray]):

si, sj = source[0], source[1] # source indices
di, dj = (
source[0] + directions[direction][0],
source[1] + directions[direction][1],
source[0] + DIRECTIONS[direction][0],
source[1] + DIRECTIONS[direction][1],
) # destination indices

send_half = actions[agent][3]
Expand Down
2 changes: 1 addition & 1 deletion generals/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def handle_events(self):
if event.type == pygame.KEYDOWN and self.from_replay:
# Speed up game right arrow is pressed
if event.key == pygame.K_RIGHT and not self.paused:
self.game_speed = max(1 / 16, self.game_speed / 2)
self.game_speed = max(1 / 128, self.game_speed / 2)
# Slow down game left arrow is pressed
if event.key == pygame.K_LEFT and not self.paused:
self.game_speed = min(32, self.game_speed * 2)
Expand Down

0 comments on commit 0309922

Please sign in to comment.