Skip to content

Commit

Permalink
🚧 Attempt at unrolling loop behaviors without direct alternatives
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathïs Fédérico committed Jan 29, 2024
1 parent e6333d3 commit cc11b6c
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 63 deletions.
7 changes: 5 additions & 2 deletions src/hebg/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from matplotlib.axes import Axes
from matplotlib.legend import Legend
from matplotlib.legend_handler import HandlerPatch
from networkx import draw_networkx_edges
from networkx import draw_networkx_edges, spring_layout
from scipy.spatial import ConvexHull # pylint: disable=no-name-in-module

from hebg.graph import draw_networkx_nodes_images
Expand All @@ -37,7 +37,10 @@ def draw_hebgraph(
plt.setp(ax.spines.values(), color="orange")

if pos is None:
pos = staircase_layout(graph)
if len(graph.roots) > 0:
pos = staircase_layout(graph)
else:
pos = spring_layout(graph)
draw_networkx_nodes_images(graph, pos, ax=ax, img_zoom=0.5)

draw_networkx_edges(
Expand Down
59 changes: 42 additions & 17 deletions src/hebg/unrolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _unroll_graph(
if _unrolled_behaviors is None:
_unrolled_behaviors = {}
if _current_alternatives is None:
_current_alternatives = []
_current_alternatives = {0: []}

is_looping = False
_unrolled_behaviors[graph.behavior.name] = None
Expand All @@ -67,14 +67,9 @@ def _unroll_graph(
for node in list(unrolled_graph.nodes()):
if not isinstance(node, Behavior):
continue
new_alternatives = []
for pred, _node, data in graph.in_edges(node, data=True):
index = data["index"]
for _pred, alternative, alt_index in graph.out_edges(pred, data="index"):
if index == alt_index and alternative != node:
new_alternatives.append(alternative)
if new_alternatives:
_current_alternatives = new_alternatives

_current_alternatives[0] = _direct_alternatives(node, graph)
_current_alternatives[1] = _roots_alternatives(node, graph)
unrolled_graph, behavior_is_looping = _unroll_behavior(
unrolled_graph,
node,
Expand All @@ -89,6 +84,25 @@ def _unroll_graph(
return unrolled_graph, is_looping


def _direct_alternatives(node: "Node", graph: "HEBGraph"):
alternatives = []
for pred, _node, data in graph.in_edges(node, data=True):
index = data["index"]
for _pred, alternative, alt_index in graph.out_edges(pred, data="index"):
if index != alt_index or alternative == node:
continue
alternatives.append(alternative)
return alternatives


def _roots_alternatives(node: "Node", graph: "HEBGraph"):
alternatives = []
for pred, _node, data in graph.in_edges(node, data=True):
if pred in graph.roots:
alternatives.extend([r for r in graph.roots if r != pred])
return alternatives


def _unroll_behavior(
graph: "HEBGraph",
behavior: "Behavior",
Expand Down Expand Up @@ -121,13 +135,24 @@ def _unroll_behavior(
)

if is_looping and cut_looping_alternatives:
if not _current_alternatives:
return graph, is_looping
for alternative in _current_alternatives:
for alternative in _current_alternatives[0]:
for last_condition, _, data in graph.in_edges(behavior, data=True):
graph.add_edge(last_condition, alternative, **data)
graph.remove_node(behavior)
return graph, False
if _current_alternatives[0]:
graph.remove_node(behavior)
return graph, False
if _current_alternatives[1]:
predecessors = list(graph.predecessors(behavior))
for last_condition in predecessors:
successors = list(graph.successors(last_condition))
for descendant in successors:
graph.remove_edge(last_condition, descendant)
if graph.neighbors(descendant) == 0:
graph.remove_node(descendant)
graph.remove_node(last_condition)
graph.remove_node(behavior)
return graph, False
raise NotImplementedError()

if node_graph is None:
# If we cannot get the node's graph, we keep it as is.
Expand All @@ -153,7 +178,7 @@ def _unrolled_behavior_graph(
cut_looping_alternatives: bool,
_current_alternatives: List[Union["Action", "Behavior"]],
_unrolled_behaviors: Dict[str, Optional["HEBGraph"]],
) -> Optional["HEBGraph"]:
) -> Tuple[Optional["HEBGraph"], bool]:
"""Get the unrolled sub-graph of a behavior.
Args:
Expand Down Expand Up @@ -218,9 +243,9 @@ def group_behaviors_points(
for i in range(len(groups[:-1])):
key = tuple(groups[: -1 - i])
point = pos[node]
try:
if key in points_grouped_by_behavior:
points_grouped_by_behavior[key].append(point)
except KeyError:
else:
points_grouped_by_behavior[key] = [point]
return points_grouped_by_behavior

Expand Down
6 changes: 1 addition & 5 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Protocol
from matplotlib import pyplot as plt
import networkx as nx


class Graph(Protocol):
Expand All @@ -18,9 +17,6 @@ def nodes(self) -> list:

def plot_graph(graph: Graph, **kwargs):
_, ax = plt.subplots()
pos = None
if len(list(graph.nodes())) == 0:
pos = nx.spring_layout(graph)
graph.draw(ax, pos=pos, **kwargs)
graph.draw(ax, **kwargs)
plt.axis("off") # turn off axis
plt.show()
4 changes: 2 additions & 2 deletions tests/examples/behaviors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from tests.examples.behaviors.binary_sum import build_binary_sum_behavior
from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
from tests.examples.behaviors.loop_without_alternative import (
build_looping_behaviors_without_alternatives,
build_looping_behaviors_without_direct_alternatives,
)


Expand All @@ -28,5 +28,5 @@
"E_E_A_Behavior",
"build_binary_sum_behavior",
"build_looping_behaviors",
"build_looping_behaviors_without_alternatives",
"build_looping_behaviors_without_direct_alternatives",
]
2 changes: 1 addition & 1 deletion tests/examples/behaviors/loop_without_alternative.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def build_graph(self) -> HEBGraph:
return graph


def build_looping_behaviors_without_alternatives() -> List[Behavior]:
def build_looping_behaviors_without_direct_alternatives() -> List[Behavior]:
behaviors: List[Behavior] = [
ReachForest(),
ReachOtherZone(),
Expand Down
77 changes: 41 additions & 36 deletions tests/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
from tests.examples.behaviors.loop_without_alternative import (
build_looping_behaviors_without_alternatives,
build_looping_behaviors_without_direct_alternatives,
)


Expand All @@ -20,9 +20,9 @@ def setup_method(self):

def test_unroll_gather_wood(self):
draw = False
unrolled_graph = unroll_graph(self.gather_wood.graph)
unrolled_graph = unroll_graph(self.gather_wood.graph, add_prefix=True)
if draw:
plot_graph(unrolled_graph)
plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

expected_graph = nx.DiGraph()
expected_graph.add_edge("Has axe", "Punch tree")
Expand All @@ -37,9 +37,9 @@ def test_unroll_gather_wood(self):

def test_unroll_get_new_axe(self):
draw = False
unrolled_graph = unroll_graph(self.get_new_axe.graph)
unrolled_graph = unroll_graph(self.get_new_axe.graph, add_prefix=True)
if draw:
plot_graph(unrolled_graph)
plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

expected_graph = nx.DiGraph()
expected_graph.add_edge("Has wood", "Has axe")
Expand All @@ -55,10 +55,10 @@ def test_unroll_get_new_axe(self):
def test_unroll_gather_wood_cutting_alternatives(self):
draw = False
unrolled_graph = unroll_graph(
self.gather_wood.graph, cut_looping_alternatives=True
self.gather_wood.graph, add_prefix=True, cut_looping_alternatives=True
)
if draw:
plot_graph(unrolled_graph)
plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

expected_graph = nx.DiGraph()
expected_graph.add_edge("Has axe", "Punch tree")
Expand All @@ -74,44 +74,49 @@ def test_unroll_gather_wood_cutting_alternatives(self):
def test_unroll_get_new_axe_cutting_alternatives(self):
draw = False
unrolled_graph = unroll_graph(
self.get_new_axe.graph,
cut_looping_alternatives=True,
self.get_new_axe.graph, add_prefix=True, cut_looping_alternatives=True
)
if draw:
plot_graph(unrolled_graph)

expected_graph = nx.DiGraph()
expected_graph.add_edge("Has wood", "Has axe")
expected_graph.add_edge("Has wood", "Craft new axe")
expected_graph.add_edge("Has wood", "Summon axe out of thin air")

# Expected sub-behavior
expected_graph.add_edge("Has axe", "Punch tree")
expected_graph.add_edge("Has axe", "Cut tree with axe")
plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

expected_graph = nx.DiGraph(
[
("Has wood", "Has axe"),
("Has wood", "Craft new axe"),
("Has wood", "Summon axe out of thin air"),
# Expected sub-behavior
("Has axe", "Punch tree"),
("Has axe", "Cut tree with axe"),
]
)
check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))


class TestLoopWithoutAlternative:
"""Tests for the loop without alternative example"""

@pytest.fixture(autouse=True)
def setup_method(self):
(
self.reach_forest,
self.reach_other_zone,
self.reach_meadow,
) = build_looping_behaviors_without_alternatives()

@pytest.mark.xfail
def test_unroll_reach_forest(self):
def test_unroll_root_alternative_reach_forest(self):
(
reach_forest,
_reach_other_zone,
_reach_meadow,
) = build_looping_behaviors_without_direct_alternatives()
draw = False
unrolled_graph = unroll_graph(
self.reach_forest.graph,
reach_forest.graph,
add_prefix=True,
cut_looping_alternatives=True,
)
if draw:
plot_graph(unrolled_graph)

expected_graph = nx.DiGraph()
plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

expected_graph = nx.DiGraph(
[
# ("Root", "Is in other zone ?"),
# ("Root", "Is in meadow ?"),
("Is in other zone ?", "Reach other zone"),
("Is in other zone ?", "Go to forest"),
("Is in meadow ?", "Go to forest"),
("Is in meadow ?", "Reach meadow>Is in other zones ?"),
("Reach meadow>Is in other zone ?", "Reach meadow>Reach other zone"),
("Reach meadow>Is in other zone ?", "Reach meadow>Go to forest"),
]
)
check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))

0 comments on commit cc11b6c

Please sign in to comment.