Skip to content

Commit

Permalink
⚡️ Remove envs instanciations during module init (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
MathisFederico authored Jan 16, 2024
2 parents a283ef3 + 1af3f87 commit 58117b0
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 60 deletions.
21 changes: 6 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,14 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.254'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- repo: local
hooks:
- id: pytest-fast-check
name: pytest-fast-check
entry: ./venv/Scripts/python.exe -m pytest -m "not slow"
stages: ["commit"]
language: system
pass_filenames: false
always_run: true
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format
types_or: [ python, pyi, jupyter ]
- repo: local
hooks:
- id: pytest-check
Expand Down
12 changes: 7 additions & 5 deletions src/hcraft/examples/minecraft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
from typing import Optional

import hcraft.examples.minecraft.items as items
from hcraft.examples.minecraft.env import MineHcraftEnv
from hcraft.purpose import Purpose, RewardShaping, platinium_purpose
from hcraft.examples.minecraft.env import ALL_ITEMS, MineHcraftEnv

from hcraft.purpose import Purpose, RewardShaping
from hcraft.task import GetItemTask

MINEHCRAFT_GYM_ENVS = []
__all__ = ["MineHcraftEnv"]


# gym is an optional dependency
try:
import gym

ENV_PATH = "hcraft.examples.minecraft.env:MineHcraftEnv"
MC_WORLD = MineHcraftEnv().world

# Simple MineHcraft with no reward, only penalty on illegal actions
gym.register(
Expand All @@ -34,7 +36,7 @@
gym.register(
id="MineHcraft-v1",
entry_point=ENV_PATH,
kwargs={"purpose": platinium_purpose(MC_WORLD)},
kwargs={"purpose": "all"},
)
MINEHCRAFT_GYM_ENVS.append("MineHcraft-v1")

Expand Down Expand Up @@ -71,7 +73,7 @@ def _register_minehcraft_single_item(
items.ENDER_DRAGON_HEAD: "Dragon",
}

for item in MC_WORLD.items:
for item in ALL_ITEMS:
cap_item_name = "".join([part.capitalize() for part in item.name.split("_")])
item_id = replacement_names.get(item, cap_item_name)
_register_minehcraft_single_item(item, name=item_id)
Expand Down
30 changes: 27 additions & 3 deletions src/hcraft/examples/minecraft/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,26 @@

from hcraft.elements import Stack
from hcraft.env import HcraftEnv
from hcraft.examples.minecraft.items import CLOSE_ENDER_PORTAL, OPEN_NETHER_PORTAL
from hcraft.examples.minecraft.items import (
CLOSE_ENDER_PORTAL,
CRAFTABLE_ITEMS,
MC_FINDABLE_ITEMS,
OPEN_NETHER_PORTAL,
PLACABLE_ITEMS,
)
from hcraft.examples.minecraft.tools import MC_TOOLS
from hcraft.examples.minecraft.transformations import (
build_minehcraft_transformations,
)
from hcraft.examples.minecraft.zones import FOREST, NETHER, STRONGHOLD
from hcraft.examples.minecraft.zones import FOREST, MC_ZONES, NETHER, STRONGHOLD
from hcraft.purpose import platinium_purpose
from hcraft.world import world_from_transformations

ALL_ITEMS = set(
MC_TOOLS + CRAFTABLE_ITEMS + [mcitem.item for mcitem in MC_FINDABLE_ITEMS]
)
"""Set of all items"""


class MineHcraftEnv(HcraftEnv):

Expand All @@ -31,6 +44,9 @@ def __init__(self, **kwargs):
resources_path = os.path.join(mc_dir, "resources")
mc_transformations = build_minehcraft_transformations()
start_zone = kwargs.pop("start_zone", FOREST)
purpose = kwargs.pop("purpose", None)
if purpose == "all":
purpose = get_platinum_purpose()
mc_world = world_from_transformations(
mc_transformations,
start_zone=start_zone,
Expand All @@ -40,5 +56,13 @@ def __init__(self, **kwargs):
},
)
mc_world.resources_path = resources_path
super().__init__(world=mc_world, name="MineHcraft", **kwargs)
super().__init__(world=mc_world, name="MineHcraft", purpose=purpose, **kwargs)
self.metadata["video.frames_per_second"] = kwargs.pop("fps", 10)


def get_platinum_purpose():
return platinium_purpose(
items=list(ALL_ITEMS),
zones=MC_ZONES,
zones_items=PLACABLE_ITEMS,
)
37 changes: 35 additions & 2 deletions src/hcraft/examples/minecraft/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"""CRAFTING_TABLE"""

FURNACE = Item("furnace")
"""CRAFTING_TABLE"""
"""FURNACE"""

STICK = Item("stick")
"""STICK"""
Expand All @@ -64,6 +64,25 @@
"""WOOD_PLANK"""


CRAFTABLE_ITEMS = [
IRON_INGOT,
GOLD_INGOT,
PAPER,
BOOK,
CLOCK,
ENCHANTING_TABLE,
CRAFTING_TABLE,
FURNACE,
STICK,
BLAZE_POWDER,
ENDER_EYE,
FLINT,
FLINT_AND_STEEL,
WOOD_PLANK,
]
"""Items that can be obtained with crafting."""


@dataclass
class McItem:
"""Minecraft item with its specific properties."""
Expand Down Expand Up @@ -224,7 +243,7 @@ class McItem:
)
"""ENDER_DRAGON_HEAD"""

MC_ITEMS = [
MC_FINDABLE_ITEMS = [
MC_DIRT,
MC_WOOD,
MC_GRAVEL,
Expand All @@ -243,6 +262,7 @@ class McItem:
MC_ENDER_PEARL,
MC_ENDER_DRAGON_HEAD,
]
"""McItems that can be gathered with or without tools."""

#: Buildings
CLOSE_NETHER_PORTAL = Item("close_nether_portal")
Expand All @@ -256,3 +276,16 @@ class McItem:

OPEN_ENDER_PORTAL = Item("open_ender_portal")
"""OPEN_ENDER_PORTAL"""
BUIDINGS = [
CLOSE_NETHER_PORTAL,
OPEN_NETHER_PORTAL,
CLOSE_ENDER_PORTAL,
OPEN_ENDER_PORTAL,
]

PLACABLE_ITEMS = [
CRAFTING_TABLE,
FURNACE,
ENCHANTING_TABLE,
] + BUIDINGS
"""Items that can be placed."""
12 changes: 6 additions & 6 deletions src/hcraft/examples/minecraft/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@

def build_minehcraft_transformations() -> List[Transformation]:
transformations = []
transformations += _move_to_zones()
transformations += _zones_search()
transformations += _building()
transformations += _recipes()
transformations += _tools_recipes()
transformations += _zones_search()
transformations += _move_to_zones()
return transformations


Expand All @@ -49,7 +49,7 @@ def _move_to_zones() -> List[Transformation]:
destination=zones.SWAMP,
zones=[zone for zone in zones.OVERWORLD if zone != zones.SWAMP],
),
#: Move to zones.MEADOW
#: Move to MEADOW
Transformation(
name_prefix + zones.MEADOW.name,
destination=zones.MEADOW,
Expand Down Expand Up @@ -112,8 +112,8 @@ def _move_to_zones() -> List[Transformation]:
Transformation(
name_prefix + zones.STRONGHOLD.name,
destination=zones.STRONGHOLD,
zones=zones.OVERWORLD,
inventory_changes=[Use(CURRENT_ZONE, items.OPEN_NETHER_PORTAL, consume=2)],
zones=[zone for zone in zones.OVERWORLD if zone != zones.STRONGHOLD],
inventory_changes=[Use(PLAYER, items.ENDER_EYE)],
),
#: Move to zones.END
Transformation(
Expand All @@ -139,7 +139,7 @@ def _zones_search() -> List[Transformation]:

name_prefix = "search-for-"
search_item = []
for mc_item in items.MC_ITEMS:
for mc_item in items.MC_FINDABLE_ITEMS:
item = mc_item.item

if mc_item.required_tool_types is None:
Expand Down
2 changes: 2 additions & 0 deletions src/hcraft/examples/minecraft/zones.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
STRONGHOLD = Zone("stronghold") #: STRONGHOLD

OVERWORLD = [FOREST, SWAMP, MEADOW, UNDERGROUND, BEDROCK, STRONGHOLD]

MC_ZONES = OVERWORLD + [NETHER, END]
3 changes: 1 addition & 2 deletions src/hcraft/examples/minicraft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@

ENV_PATH = "hcraft.examples.minicraft"

for env_class in MINICRAFT_ENVS:
env_name = env_class().name
for env_name, env_class in MINICRAFT_NAME_TO_ENV.items():
submodule = Path(inspect.getfile(env_class)).name.split(".")[0]
env_path = f"{ENV_PATH}.{submodule}:{env_class.__name__}"
gym_name = f"{env_name}-v1"
Expand Down
19 changes: 11 additions & 8 deletions src/hcraft/purpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@

from hcraft.requirements import RequirementNode, req_node_name
from hcraft.task import GetItemTask, GoToZoneTask, PlaceItemTask, Task
from hcraft.elements import Item, Zone


if TYPE_CHECKING:
from hcraft.elements import Item, Zone
from hcraft.env import HcraftEnv, HcraftState
from hcraft.world import World

Expand Down Expand Up @@ -353,16 +354,18 @@ def _tasks_str(self, tasks: List[Task]) -> str:


def platinium_purpose(
world: "World",
items: List[Item],
zones: List[Zone],
zones_items: List[Item],
success_reward: float = 10.0,
timestep_reward: float = -0.1,
):
purpose = Purpose(timestep_reward=timestep_reward)
for item in world.items:
for item in items:
purpose.add_task(GetItemTask(item, reward=success_reward))
for zone in world.zones:
for zone in zones:
purpose.add_task(GoToZoneTask(zone, reward=success_reward))
for item in world.zones_items:
for item in zones_items:
purpose.add_task(PlaceItemTask(item, reward=success_reward))
return purpose

Expand Down Expand Up @@ -480,9 +483,9 @@ def _inputs_subtasks(task: Task, world: "World", shaping_reward: float) -> List[


def _build_reward_shaping_subtasks(
items: Optional[Union[List["Item"], Set["Item"]]] = None,
zones: Optional[Union[List["Zone"], Set["Zone"]]] = None,
zone_items: Optional[Union[List["Item"], Set["Item"]]] = None,
items: Optional[Union[List[Item], Set[Item]]] = None,
zones: Optional[Union[List[Zone], Set[Zone]]] = None,
zone_items: Optional[Union[List[Item], Set[Item]]] = None,
shaping_reward: float = 1.0,
) -> List[Task]:
subtasks = []
Expand Down
1 change: 0 additions & 1 deletion src/hcraft/solving_behaviors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def build_all_solving_behaviors(world: "World") -> Dict[str, "Behavior"]:
all_behaviors = {}
all_behaviors = _reach_zones_behaviors(world, all_behaviors)
all_behaviors = _get_item_behaviors(world, all_behaviors)
# all_behaviors = _drop_item_behaviors(world, all_behaviors)
all_behaviors = _get_zone_item_behaviors(world, all_behaviors)
all_behaviors = _do_transfo_behaviors(world, all_behaviors)
return all_behaviors
Expand Down
36 changes: 35 additions & 1 deletion src/hcraft/state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np

from hcraft.transformation import InventoryOwner

if TYPE_CHECKING:
from hcraft.world import World
from hcraft.elements import Zone, Item
Expand Down Expand Up @@ -105,6 +107,22 @@ def current_zone(self) -> Optional["Zone"]:
def _current_zone_slot(self) -> int:
return self.position.nonzero()[0]

@property
def player_inventory_dict(self) -> Dict["Item", int]:
"""Current inventory of the player."""
return self._inv_as_dict(self.player_inventory, self.world.items)

@property
def zones_inventories_dict(self) -> Dict["Zone", Dict["Item", int]]:
"""Current inventories of the current zone and each zone containing item."""
zones_invs = {}
for zone_slot, zone_inv in enumerate(self.zones_inventories):
zone = self.world.zones[zone_slot]
zone_inv = self._inv_as_dict(zone_inv, self.world.zones_items)
if zone_slot == self._current_zone_slot or zone_inv:
zones_invs[zone] = zone_inv
return zones_invs

def apply(self, action: int) -> bool:
"""Apply the given action to update the state.
Expand Down Expand Up @@ -166,3 +184,19 @@ def _update_discoveries(self, action: Optional[int] = None) -> None:
self.discovered_zones = np.bitwise_or(self.discovered_zones, self.position > 0)
if action is not None:
self.discovered_transformations[action] = 1

@staticmethod
def _inv_as_dict(inventory_array: np.ndarray, obj_registry: list):
return {
obj_registry[index]: value
for index, value in enumerate(inventory_array)
if value > 0
}

def as_dict(self) -> dict:
state_dict = {
"pos": self.current_zone,
InventoryOwner.PLAYER.value: self.player_inventory_dict,
}
state_dict.update(self.zones_inventories_dict)
return state_dict
Loading

0 comments on commit 58117b0

Please sign in to comment.