Skip to content

Commit

Permalink
type hints support
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Feb 8, 2024
1 parent 6bb3930 commit ef78964
Show file tree
Hide file tree
Showing 22 changed files with 160 additions and 102 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/codestyle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ jobs:
pip install -e ".[dev]"
- name: check codestyle
run: |
ruff --config pyproject.toml --diff .
ruff --config pyproject.toml --diff .
- name: check type hints
run: |
pyright --project=pyproject.toml src/xminigrid
11 changes: 9 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
repos:
# ruff checking
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.4
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
- id: ruff-format

# pyright checking
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.350
hooks:
- id: pyright
args: [--project=pyproject.toml]
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ classifiers = [
dependencies = [
"jax>=0.4.16",
"jaxlib>=0.4.16",
"flax>=0.7.0",
"flax>=0.8.0",
"rich>=13.4.2",
]

Expand Down Expand Up @@ -96,7 +96,7 @@ exclude = [
"**/__pycache__",
]

reportMissingImports = true
reportMissingImports = "none"
reportMissingTypeStubs = false
pythonVersion = "3.10"
pythonPlatform = "All"
2 changes: 1 addition & 1 deletion src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.5.1"
__version__ = "0.6.0"

# ---------- XLand-MiniGrid environments ----------

Expand Down
14 changes: 7 additions & 7 deletions src/xminigrid/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .core.rules import check_rule
from .rendering.rgb_render import render as rgb_render
from .rendering.text_render import render as text_render
from .types import IntOrArray, State, StepType, TimeStep
from .types import EnvCarryT, IntOrArray, State, StepType, TimeStep


class EnvParams(struct.PyTreeNode):
Expand All @@ -32,7 +32,7 @@ class EnvParams(struct.PyTreeNode):
EnvParamsT = TypeVar("EnvParamsT", bound="EnvParams")


class Environment(abc.ABC, Generic[EnvParamsT]):
class Environment(abc.ABC, Generic[EnvParamsT, EnvCarryT]):
@abc.abstractmethod
def default_params(self, **kwargs: Any) -> EnvParamsT:
...
Expand All @@ -48,10 +48,10 @@ def time_limit(self, params: EnvParamsT) -> int:
return 3 * params.height * params.width

@abc.abstractmethod
def _generate_problem(self, params: EnvParamsT, key: jax.Array) -> State:
def _generate_problem(self, params: EnvParamsT, key: jax.Array) -> State[EnvCarryT]:
...

def reset(self, params: EnvParamsT, key: jax.Array) -> TimeStep:
def reset(self, params: EnvParamsT, key: jax.Array) -> TimeStep[EnvCarryT]:
state = self._generate_problem(params, key)
timestep = TimeStep(
state=state,
Expand All @@ -63,7 +63,7 @@ def reset(self, params: EnvParamsT, key: jax.Array) -> TimeStep:
return timestep

# Why timestep + state at once, and not like in Jumanji? To be able to do auto-resets in both the Gym and EnvPool styles.
def step(self, params: EnvParamsT, timestep: TimeStep, action: IntOrArray) -> TimeStep:
def step(self, params: EnvParamsT, timestep: TimeStep[EnvCarryT], action: IntOrArray) -> TimeStep[EnvCarryT]:
new_grid, new_agent, changed_position = take_action(timestep.state.grid, timestep.state.agent, action)
new_grid, new_agent = check_rule(timestep.state.rule_encoding, new_grid, new_agent, action, changed_position)

Expand Down Expand Up @@ -92,9 +92,9 @@ def step(self, params: EnvParamsT, timestep: TimeStep, action: IntOrArray) -> Ti
)
return timestep

def render(self, params: EnvParamsT, timestep: TimeStep) -> np.ndarray | str:
def render(self, params: EnvParamsT, timestep: TimeStep[EnvCarryT]) -> np.ndarray | str:
if params.render_mode == "rgb_array":
return rgb_render(timestep.state.grid, timestep.state.agent, params.view_size)
return rgb_render(np.asarray(timestep.state.grid), timestep.state.agent, params.view_size)
elif params.render_mode == "rich_text":
return text_render(timestep.state.grid, timestep.state.agent)
else:
Expand Down
6 changes: 4 additions & 2 deletions src/xminigrid/envs/minigrid/blockedunlockpickup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -31,7 +33,7 @@
_rule_encoding = EmptyRule().encode()[None, ...]


class BlockedUnlockPickUp(Environment[EnvParams]):
class BlockedUnlockPickUp(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=6, width=11)
default_params = default_params.replace(**kwargs)
Expand All @@ -40,7 +42,7 @@ def default_params(self, **kwargs) -> EnvParams:
def time_limit(self, params: EnvParams) -> int:
return 16 * params.height**2

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, _key = jax.random.split(key)
keys = jax.random.split(_key, num=7)

Expand Down
6 changes: 4 additions & 2 deletions src/xminigrid/envs/minigrid/doorkey.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

Expand All @@ -12,7 +14,7 @@
_rule_encoding = EmptyRule().encode()[None, ...]


class DoorKey(Environment[EnvParams]):
class DoorKey(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=5, width=5)
default_params = default_params.replace(**kwargs)
Expand All @@ -21,7 +23,7 @@ def default_params(self, **kwargs) -> EnvParams:
def time_limit(self, params: EnvParams) -> int:
return 10 * (params.height * params.width)

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, _key = jax.random.split(key)
keys = jax.random.split(_key, num=4)

Expand Down
10 changes: 6 additions & 4 deletions src/xminigrid/envs/minigrid/empty.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

Expand All @@ -13,7 +15,7 @@
_rule_encoding = EmptyRule().encode()[None, ...]


class Empty(Environment[EnvParams]):
class Empty(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=9, width=9)
default_params = default_params.replace(**kwargs)
Expand All @@ -22,7 +24,7 @@ def default_params(self, **kwargs) -> EnvParams:
def time_limit(self, params: EnvParams) -> int:
return 4 * (params.height * params.width)

def _generate_problem(self, params: EnvParams, key: jax.Array):
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
grid = room(params.height, params.width)

grid = grid.at[params.height - 2, params.width - 2].set(TILES_REGISTRY[Tiles.GOAL, Colors.GREEN])
Expand All @@ -40,7 +42,7 @@ def _generate_problem(self, params: EnvParams, key: jax.Array):
return state


class EmptyRandom(Environment[EnvParams]):
class EmptyRandom(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=9, width=9)
default_params = default_params.replace(**kwargs)
Expand All @@ -49,7 +51,7 @@ def default_params(self, **kwargs) -> EnvParams:
def time_limit(self, params: EnvParams) -> int:
return 4 * (params.height * params.width)

def _generate_problem(self, params: EnvParams, key: jax.Array):
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, pos_key, dir_key = jax.random.split(key, num=3)

grid = room(params.height, params.width)
Expand Down
6 changes: 4 additions & 2 deletions src/xminigrid/envs/minigrid/fourrooms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

Expand All @@ -13,7 +15,7 @@
_rule_encoding = EmptyRule().encode()[None, ...]


class FourRooms(Environment[EnvParams]):
class FourRooms(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=19, width=19)
default_params = default_params.replace(**kwargs)
Expand All @@ -23,7 +25,7 @@ def time_limit(self, params: EnvParams) -> int:
# TODO: this is hardcoded and thus problematic. Move it to EnvParams?
return 100

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, *keys = jax.random.split(key, num=4)

grid = four_rooms(params.height, params.width)
Expand Down
6 changes: 4 additions & 2 deletions src/xminigrid/envs/minigrid/lockedroom.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -26,7 +28,7 @@
)


class LockedRoom(Environment[EnvParams]):
class LockedRoom(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=19, width=19)
default_params = default_params.replace(**kwargs)
Expand All @@ -35,7 +37,7 @@ def default_params(self, **kwargs) -> EnvParams:
def time_limit(self, params: EnvParams) -> int:
return 10 * params.height

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, rooms_key, colors_key, objects_key, coords_key, agent_pos_key, agent_dir_key = jax.random.split(key, num=7)

# set up rooms
Expand Down
16 changes: 11 additions & 5 deletions src/xminigrid/envs/minigrid/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import jax
import jax.numpy as jnp
from flax import struct

from ...core.actions import take_action
from ...core.constants import TILES_REGISTRY, Colors, Tiles
Expand All @@ -8,7 +11,7 @@
from ...core.observation import transparent_field_of_view
from ...core.rules import EmptyRule
from ...environment import Environment, EnvParams
from ...types import AgentState, EnvCarry, IntOrArray, State, StepType, TimeStep
from ...types import AgentState, IntOrArray, State, StepType, TimeStep # , EnvCarry

_goal_encoding = EmptyGoal().encode()
_rule_encoding = EmptyRule().encode()[None, ...]
Expand All @@ -27,13 +30,14 @@

# It can be made to be a goal, but for demonstration
# purposes (how to use carry) we decided to leave it as is
class MemoryEnvCarry(EnvCarry):
# class MemoryEnvCarry(EnvCarry):
class MemoryEnvCarry(struct.PyTreeNode):
success_pos: jax.Array
failure_pos: jax.Array


# TODO: Random corridor length is a bit problematic due to the dynamic slicing.
class Memory(Environment[EnvParams]):
class Memory(Environment[EnvParams, MemoryEnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=7, width=13, view_size=3)
default_params = default_params.replace(**kwargs)
Expand All @@ -42,7 +46,7 @@ def default_params(self, **kwargs) -> EnvParams:
def time_limit(self, params: EnvParams) -> int:
return 5 * params.width**2

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[MemoryEnvCarry]:
key, corridor_key, agent_key, mem_key, place_key = jax.random.split(key, num=5)

corridor_length = params.width - 6
Expand Down Expand Up @@ -101,7 +105,9 @@ def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
)
return state

def step(self, params: EnvParams, timestep: TimeStep, action: IntOrArray) -> TimeStep:
def step(
self, params: EnvParams, timestep: TimeStep[MemoryEnvCarry], action: IntOrArray
) -> TimeStep[MemoryEnvCarry]:
# disabling pick_up action
action = jax.lax.select(jnp.equal(action, 3), 5, action)
new_grid, new_agent, _ = take_action(timestep.state.grid, timestep.state.agent, action)
Expand Down
6 changes: 4 additions & 2 deletions src/xminigrid/envs/minigrid/playground.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import jax
import jax.numpy as jnp
from flax import struct
Expand Down Expand Up @@ -41,14 +43,14 @@ class PlaygroundEnvParams(EnvParams):
num_objects: int = struct.field(pytree_node=False, default=12)


class Playground(Environment[PlaygroundEnvParams]):
class Playground(Environment[PlaygroundEnvParams, EnvCarry]):
def default_params(self, **kwargs) -> PlaygroundEnvParams:
return PlaygroundEnvParams(height=19, width=19).replace(**kwargs)

def time_limit(self, params: EnvParams) -> int:
return 512

def _generate_problem(self, params: PlaygroundEnvParams, key: jax.Array) -> State:
def _generate_problem(self, params: PlaygroundEnvParams, key: jax.Array) -> State[EnvCarry]:
key, *keys = jax.random.split(key, num=6)

grid = nine_rooms(params.height, params.width)
Expand Down
6 changes: 4 additions & 2 deletions src/xminigrid/envs/minigrid/unlock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

Expand All @@ -22,7 +24,7 @@
_rule_encoding = EmptyRule().encode()[None, ...]


class Unlock(Environment[EnvParams]):
class Unlock(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=6, width=11)
default_params = default_params.replace(**kwargs)
Expand All @@ -31,7 +33,7 @@ def default_params(self, **kwargs) -> EnvParams:
def time_limit(self, params: EnvParams) -> int:
return 8 * params.height**2

def _generate_problem(self, params: EnvParams, key: jax.Array):
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, *keys = jax.random.split(key, num=5)

color = jax.random.choice(keys[0], _allowed_colors)
Expand Down
6 changes: 4 additions & 2 deletions src/xminigrid/envs/minigrid/unlockpickup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -31,7 +33,7 @@
_rule_encoding = EmptyRule().encode()[None, ...]


class UnlockPickUp(Environment[EnvParams]):
class UnlockPickUp(Environment[EnvParams, EnvCarry]):
def default_params(self, **kwargs) -> EnvParams:
default_params = EnvParams(height=6, width=11)
default_params = default_params.replace(**kwargs)
Expand All @@ -40,7 +42,7 @@ def default_params(self, **kwargs) -> EnvParams:
def time_limit(self, params: EnvParams) -> int:
return 8 * params.height**2

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
def _generate_problem(self, params: EnvParams, key: jax.Array) -> State[EnvCarry]:
key, *keys = jax.random.split(key, num=7)

obj = jax.random.choice(keys[0], _allowed_entities)
Expand Down
Loading

0 comments on commit ef78964

Please sign in to comment.