Skip to content

Commit

Permalink
🎉 Solve all Hcraft envs with HEBG
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathïs Fédérico committed Feb 2, 2024
1 parent 7d5b29f commit 5b05e85
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 16 deletions.
55 changes: 43 additions & 12 deletions src/hcraft/solving_behaviors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,55 @@
PlaceItem,
ReachZone,
)
from hcraft.requirements import RequirementNode, req_node_name
from hcraft.task import GetItemTask, GoToZoneTask, PlaceItemTask, Task

from hebg.unrolling import unroll_graph

if TYPE_CHECKING:
from hcraft.env import HcraftEnv
from hcraft.world import World


def build_all_solving_behaviors(world: "World") -> Dict[str, "Behavior"]:
def build_all_solving_behaviors(env: "HcraftEnv") -> Dict[str, "Behavior"]:
"""Return a dictionary of handcrafted behaviors to get each item, zone and property."""
all_behaviors = {}
all_behaviors = _reach_zones_behaviors(world, all_behaviors)
all_behaviors = _get_item_behaviors(world, all_behaviors)
all_behaviors = _drop_item_behaviors(world, all_behaviors)
all_behaviors = _get_zone_item_behaviors(world, all_behaviors)
all_behaviors = _do_transfo_behaviors(world, all_behaviors)
all_behaviors = _reach_zones_behaviors(env, all_behaviors)
all_behaviors = _get_item_behaviors(env, all_behaviors)
all_behaviors = _drop_item_behaviors(env, all_behaviors)
all_behaviors = _get_zone_item_behaviors(env, all_behaviors)
all_behaviors = _do_transfo_behaviors(env, all_behaviors)

# empty_behaviors = []
# for name, behavior in all_behaviors.items():
# try:
# behavior.graph
# except ValueError:
# empty_behaviors.append(name)
# for name in empty_behaviors:
# all_behaviors.pop(name)

# TODO: Use learning complexity instead for more generality
requirements_graph = env.world.requirements.graph

for behavior in all_behaviors.values():
if isinstance(behavior, AbleAndPerformTransformation):
behavior.complexity = 1
continue
if isinstance(behavior, GetItem):
req_node = req_node_name(behavior.item, RequirementNode.ITEM)
elif isinstance(behavior, DropItem):
# TODO: this clearly is not general enough,
# it would need requirements for non-accumulative to be fine
req_node = req_node_name(behavior.item, RequirementNode.ITEM)
elif isinstance(behavior, ReachZone):
req_node = req_node_name(behavior.zone, RequirementNode.ZONE)
elif isinstance(behavior, PlaceItem):
req_node = req_node_name(behavior.item, RequirementNode.ZONE_ITEM)
else:
raise NotImplementedError
behavior.complexity = requirements_graph.nodes[req_node]["level"]
continue

return all_behaviors


Expand Down Expand Up @@ -109,11 +143,8 @@ def _get_item_behaviors(env: "HcraftEnv", all_behaviors: Dict[str, "Behavior"]):

def _drop_item_behaviors(env: "HcraftEnv", all_behaviors: Dict[str, "Behavior"]):
for item in env.world.items:
try:
behavior = DropItem(item, env, all_behaviors=all_behaviors)
all_behaviors[behavior.name] = behavior
except ValueError:
continue
behavior = DropItem(item, env, all_behaviors=all_behaviors)
all_behaviors[behavior.name] = behavior
return all_behaviors


Expand Down
12 changes: 10 additions & 2 deletions tests/solving_behaviors/test_can_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def test_can_solve(env_class):
env: HcraftEnv = env_class(max_step=50)
draw_call_graph = False

if draw_call_graph:
_fig, ax = plt.subplots()

done = False
observation = env.reset()
for task in env.purpose.best_terminal_group.tasks:
Expand All @@ -24,11 +27,16 @@ def test_can_solve(env_class):
while not task_done and not done:
action = solving_behavior(observation)
if draw_call_graph:
_fig, ax = plt.subplots()
plt.cla()
solving_behavior.graph.call_graph.draw(ax)
plt.show()
plt.show(block=False)

if action == "Impossible":
raise ValueError("Solving behavior could not find a solution.")
observation, _reward, done, _ = env.step(action)
task_done = task.terminated

if draw_call_graph:
plt.show()

check.is_true(env.purpose.terminated)
1 change: 0 additions & 1 deletion tests/solving_behaviors/test_doc_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


@pytest.mark.slow
@pytest.mark.xfail(reason="Hebg needs to handle breadth first search")
def test_doc_example():
from hcraft.examples import MineHcraftEnv
from hcraft.examples.minecraft.items import DIAMOND
Expand Down
14 changes: 13 additions & 1 deletion tests/solving_behaviors/test_mineHcraft.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from matplotlib import pyplot as plt
from hcraft.examples.minecraft.env import MineHcraftEnv
from hcraft.task import Task

Expand All @@ -14,9 +15,13 @@


@pytest.mark.slow
@pytest.mark.xfail(reason="Destination items are not taken into account")
def test_solving_behaviors():
"""All tasks should be solved by their solving behavior."""
draw_call_graph = False

if draw_call_graph:
_fig, ax = plt.subplots()

env = MineHcraftEnv(purpose="all", max_step=500)
done = False
observation = env.reset()
Expand All @@ -35,10 +40,17 @@ def test_solving_behaviors():
print(f"Task started: {task} (step={env.current_step})")
solving_behavior = env.solving_behavior(task)
action = solving_behavior(observation)
if draw_call_graph:
plt.cla()
solving_behavior.graph.call_graph.draw(ax)
plt.show(block=False)
observation, _rew, done, _infos = env.step(action)
if task.terminated:
print(f"Task finished: {task}, tasks_left: {tasks_left}")
task = None

if draw_call_graph:
plt.show()
if isinstance(task, Task) and not task.terminated:
print(f"Last unfinished task: {task}")
if set(t.name for t in tasks_left) == set(HARD_TASKS):
Expand Down

0 comments on commit 5b05e85

Please sign in to comment.