Skip to content

Commit

Permalink
refactor: Improve some imports, remove whitespace
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 14, 2024
1 parent 30a41ef commit 90a0830
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: pre-commit hooks
name: CI

on:
push:
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
## **Generals.io RL**

[![CodeQL](https://github.com/strakam/Generals-RL/actions/workflows/codeql.yml/badge.svg)](https://github.com/strakam/Generals-RL/actions/workflows/codeql.yml)
[![CI](https://github.com/strakam/Generals-RL/actions/workflows/tests.yml/badge.svg)](https://github.com/strakam/Generals-RL/actions/workflows/tests.yml)
[![CI](https://github.com/strakam/Generals-RL/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/strakam/Generals-RL/actions/workflows/pre-commit.yml)



Expand Down Expand Up @@ -173,7 +173,7 @@ The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N
| `is_winner` || Indicates whether the agent won |
| `timestep` || Current timestep of the game |

The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell
The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell
`[i, j]` in one of four directions: `0 (up)`, `1 (down)`, `2 (left)`, or `3 (right)`.

### ⚡ Action
Expand All @@ -191,7 +191,8 @@ Actions are in a `dict` format with the following `key - value` format:
> ```
### 🎁 Reward
It is possible to implement custom reward function. The default is `1` for winner and `-1` for loser, otherwise `0`.
It is possible to implement custom reward function. The default reward is awarded only at the end of a game
and gives `1` for winner and `-1` for loser, otherwise `0`.
```python
def custom_reward_fn(observation, action, done, info):
# Give agent a reward based on the number of cells they own
Expand Down
6 changes: 3 additions & 3 deletions examples/gymnasium_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from generals import AgentFactory

# Initialize opponent agent ("random" or "expander")
npc = AgentFactory.make_agent("random")
# Initialize opponent agent
npc = AgentFactory.make_agent("expander")

# Create environment
env = gym.make("gym-generals-v0", npc=npc, render_mode="human")

observation, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
action = env.action_space.sample() # Here you put your agent's action
action = env.action_space.sample() # Here you put an action of your agent
observation, reward, terminated, truncated, info = env.step(action)
env.render()
9 changes: 5 additions & 4 deletions examples/pettingzoo_example.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from generals.agents import AgentFactory
from generals.agents import RandomAgent, ExpanderAgent
from generals.envs import PettingZooGenerals

# Initialize agents
random = AgentFactory.make_agent("random")
expander = AgentFactory.make_agent("expander")
random = RandomAgent()
expander = ExpanderAgent()

# Store agents in a dictionary
agents = {
random.id: random,
expander.id: expander,
}
agent_ids = list(agents.keys()) # Environment calls agents by name
agent_ids = list(agents.keys())

# Create environment
env = PettingZooGenerals(agents=agent_ids, render_mode="human")
Expand Down
2 changes: 1 addition & 1 deletion generals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from gymnasium.envs.registration import register

from generals.agents.agent_factory import AgentFactory
from generals.agents import AgentFactory
from generals.core.grid import Grid, GridFactory
from generals.core.replay import Replay
from generals.envs.pettingzoo_generals import PettingZooGenerals
Expand Down
10 changes: 9 additions & 1 deletion generals/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

from .agent import Agent
from .agent_factory import AgentFactory
from .expander_agent import ExpanderAgent
from .random_agent import RandomAgent

# You can also define an __all__ list if you want to restrict what gets imported with *
__all__ = ["Agent", "AgentFactory"]
__all__ = [
"Agent",
"AgentFactory",
"RandomAgent",
"ExpanderAgent",
"AgentFactory",
]
5 changes: 2 additions & 3 deletions generals/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gymnasium as gym
import numpy as np

# Type aliases
Observation: TypeAlias = dict[str, np.ndarray | dict[str, gym.Space]]
Action: TypeAlias = dict[str, int | np.ndarray]
Info: TypeAlias = dict[str, Any]
Expand All @@ -13,9 +14,7 @@
RewardFn: TypeAlias = Callable[[Observation, Action, bool, Info], Reward]
AgentID: TypeAlias = str

#################
# Game Literals #
#################
# Game Literals
PASSABLE: Literal["."] = "."
MOUNTAIN: Literal["#"] = "#"

Expand Down

0 comments on commit 90a0830

Please sign in to comment.