diff --git a/src/hebg/draw.py b/src/hebg/draw.py index 6f78920..80efd7d 100644 --- a/src/hebg/draw.py +++ b/src/hebg/draw.py @@ -10,7 +10,7 @@ from matplotlib.axes import Axes from matplotlib.legend import Legend from matplotlib.legend_handler import HandlerPatch -from networkx import draw_networkx_edges +from networkx import draw_networkx_edges, spring_layout from scipy.spatial import ConvexHull # pylint: disable=no-name-in-module from hebg.graph import draw_networkx_nodes_images @@ -37,7 +37,10 @@ def draw_hebgraph( plt.setp(ax.spines.values(), color="orange") if pos is None: - pos = staircase_layout(graph) + if len(graph.roots) > 0: + pos = staircase_layout(graph) + else: + pos = spring_layout(graph) draw_networkx_nodes_images(graph, pos, ax=ax, img_zoom=0.5) draw_networkx_edges( diff --git a/src/hebg/unrolling.py b/src/hebg/unrolling.py index dcc039d..b7aa976 100644 --- a/src/hebg/unrolling.py +++ b/src/hebg/unrolling.py @@ -58,7 +58,7 @@ def _unroll_graph( if _unrolled_behaviors is None: _unrolled_behaviors = {} if _current_alternatives is None: - _current_alternatives = [] + _current_alternatives = {0: []} is_looping = False _unrolled_behaviors[graph.behavior.name] = None @@ -67,14 +67,9 @@ def _unroll_graph( for node in list(unrolled_graph.nodes()): if not isinstance(node, Behavior): continue - new_alternatives = [] - for pred, _node, data in graph.in_edges(node, data=True): - index = data["index"] - for _pred, alternative, alt_index in graph.out_edges(pred, data="index"): - if index == alt_index and alternative != node: - new_alternatives.append(alternative) - if new_alternatives: - _current_alternatives = new_alternatives + + _current_alternatives[0] = _direct_alternatives(node, graph) + _current_alternatives[1] = _roots_alternatives(node, graph) unrolled_graph, behavior_is_looping = _unroll_behavior( unrolled_graph, node, @@ -89,6 +84,25 @@ def _unroll_graph( return unrolled_graph, is_looping +def _direct_alternatives(node: "Node", graph: "HEBGraph"): + alternatives = [] + for pred, _node, data in graph.in_edges(node, data=True): + index = data["index"] + for _pred, alternative, alt_index in graph.out_edges(pred, data="index"): + if index != alt_index or alternative == node: + continue + alternatives.append(alternative) + return alternatives + + +def _roots_alternatives(node: "Node", graph: "HEBGraph"): + alternatives = [] + for pred, _node, data in graph.in_edges(node, data=True): + if pred in graph.roots: + alternatives.extend([r for r in graph.roots if r != pred]) + return alternatives + + def _unroll_behavior( graph: "HEBGraph", behavior: "Behavior", @@ -121,13 +135,24 @@ def _unroll_behavior( ) if is_looping and cut_looping_alternatives: - if not _current_alternatives: - return graph, is_looping - for alternative in _current_alternatives: + for alternative in _current_alternatives[0]: for last_condition, _, data in graph.in_edges(behavior, data=True): graph.add_edge(last_condition, alternative, **data) - graph.remove_node(behavior) - return graph, False + if _current_alternatives[0]: + graph.remove_node(behavior) + return graph, False + if _current_alternatives[1]: + predecessors = list(graph.predecessors(behavior)) + for last_condition in predecessors: + successors = list(graph.successors(last_condition)) + for descendant in successors: + graph.remove_edge(last_condition, descendant) + if graph.neighbors(descendant) == 0: + graph.remove_node(descendant) + graph.remove_node(last_condition) + graph.remove_node(behavior) + return graph, False + raise NotImplementedError() if node_graph is None: # If we cannot get the node's graph, we keep it as is. @@ -153,7 +178,7 @@ def _unrolled_behavior_graph( cut_looping_alternatives: bool, _current_alternatives: List[Union["Action", "Behavior"]], _unrolled_behaviors: Dict[str, Optional["HEBGraph"]], -) -> Optional["HEBGraph"]: +) -> Tuple[Optional["HEBGraph"], bool]: """Get the unrolled sub-graph of a behavior. Args: @@ -218,9 +243,9 @@ def group_behaviors_points( for i in range(len(groups[:-1])): key = tuple(groups[: -1 - i]) point = pos[node] - try: + if key in points_grouped_by_behavior: points_grouped_by_behavior[key].append(point) - except KeyError: + else: points_grouped_by_behavior[key] = [point] return points_grouped_by_behavior diff --git a/tests/__init__.py b/tests/__init__.py index 20e5652..f69ceea 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,7 +5,6 @@ from typing import Protocol from matplotlib import pyplot as plt -import networkx as nx class Graph(Protocol): @@ -18,9 +17,6 @@ def nodes(self) -> list: def plot_graph(graph: Graph, **kwargs): _, ax = plt.subplots() - pos = None - if len(list(graph.nodes())) == 0: - pos = nx.spring_layout(graph) - graph.draw(ax, pos=pos, **kwargs) + graph.draw(ax, **kwargs) plt.axis("off") # turn off axis plt.show() diff --git a/tests/examples/behaviors/__init__.py b/tests/examples/behaviors/__init__.py index 2a5c779..d339dd3 100644 --- a/tests/examples/behaviors/__init__.py +++ b/tests/examples/behaviors/__init__.py @@ -13,7 +13,7 @@ from tests.examples.behaviors.binary_sum import build_binary_sum_behavior from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors from tests.examples.behaviors.loop_without_alternative import ( - build_looping_behaviors_without_alternatives, + build_looping_behaviors_without_direct_alternatives, ) @@ -28,5 +28,5 @@ "E_E_A_Behavior", "build_binary_sum_behavior", "build_looping_behaviors", - "build_looping_behaviors_without_alternatives", + "build_looping_behaviors_without_direct_alternatives", ] diff --git a/tests/examples/behaviors/loop_without_alternative.py b/tests/examples/behaviors/loop_without_alternative.py index c0715f5..94ee273 100644 --- a/tests/examples/behaviors/loop_without_alternative.py +++ b/tests/examples/behaviors/loop_without_alternative.py @@ -57,7 +57,7 @@ def build_graph(self) -> HEBGraph: return graph -def build_looping_behaviors_without_alternatives() -> List[Behavior]: +def build_looping_behaviors_without_direct_alternatives() -> List[Behavior]: behaviors: List[Behavior] = [ ReachForest(), ReachOtherZone(), diff --git a/tests/test_loop.py b/tests/test_loop.py index 8f451ca..5ab85db 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -7,7 +7,7 @@ from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors from tests.examples.behaviors.loop_without_alternative import ( - build_looping_behaviors_without_alternatives, + build_looping_behaviors_without_direct_alternatives, ) @@ -20,9 +20,9 @@ def setup_method(self): def test_unroll_gather_wood(self): draw = False - unrolled_graph = unroll_graph(self.gather_wood.graph) + unrolled_graph = unroll_graph(self.gather_wood.graph, add_prefix=True) if draw: - plot_graph(unrolled_graph) + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) expected_graph = nx.DiGraph() expected_graph.add_edge("Has axe", "Punch tree") @@ -37,9 +37,9 @@ def test_unroll_gather_wood(self): def test_unroll_get_new_axe(self): draw = False - unrolled_graph = unroll_graph(self.get_new_axe.graph) + unrolled_graph = unroll_graph(self.get_new_axe.graph, add_prefix=True) if draw: - plot_graph(unrolled_graph) + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) expected_graph = nx.DiGraph() expected_graph.add_edge("Has wood", "Has axe") @@ -55,10 +55,10 @@ def test_unroll_get_new_axe(self): def test_unroll_gather_wood_cutting_alternatives(self): draw = False unrolled_graph = unroll_graph( - self.gather_wood.graph, cut_looping_alternatives=True + self.gather_wood.graph, add_prefix=True, cut_looping_alternatives=True ) if draw: - plot_graph(unrolled_graph) + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) expected_graph = nx.DiGraph() expected_graph.add_edge("Has axe", "Punch tree") @@ -74,44 +74,49 @@ def test_unroll_gather_wood_cutting_alternatives(self): def test_unroll_get_new_axe_cutting_alternatives(self): draw = False unrolled_graph = unroll_graph( - self.get_new_axe.graph, - cut_looping_alternatives=True, + self.get_new_axe.graph, add_prefix=True, cut_looping_alternatives=True ) if draw: - plot_graph(unrolled_graph) - - expected_graph = nx.DiGraph() - expected_graph.add_edge("Has wood", "Has axe") - expected_graph.add_edge("Has wood", "Craft new axe") - expected_graph.add_edge("Has wood", "Summon axe out of thin air") - - # Expected sub-behavior - expected_graph.add_edge("Has axe", "Punch tree") - expected_graph.add_edge("Has axe", "Cut tree with axe") + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) + + expected_graph = nx.DiGraph( + [ + ("Has wood", "Has axe"), + ("Has wood", "Craft new axe"), + ("Has wood", "Summon axe out of thin air"), + # Expected sub-behavior + ("Has axe", "Punch tree"), + ("Has axe", "Cut tree with axe"), + ] + ) check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) - -class TestLoopWithoutAlternative: - """Tests for the loop without alternative example""" - - @pytest.fixture(autouse=True) - def setup_method(self): - ( - self.reach_forest, - self.reach_other_zone, - self.reach_meadow, - ) = build_looping_behaviors_without_alternatives() - @pytest.mark.xfail - def test_unroll_reach_forest(self): + def test_unroll_root_alternative_reach_forest(self): + ( + reach_forest, + _reach_other_zone, + _reach_meadow, + ) = build_looping_behaviors_without_direct_alternatives() draw = False unrolled_graph = unroll_graph( - self.reach_forest.graph, + reach_forest.graph, add_prefix=True, cut_looping_alternatives=True, ) if draw: - plot_graph(unrolled_graph) - - expected_graph = nx.DiGraph() + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) + + expected_graph = nx.DiGraph( + [ + # ("Root", "Is in other zone ?"), + # ("Root", "Is in meadow ?"), + ("Is in other zone ?", "Reach other zone"), + ("Is in other zone ?", "Go to forest"), + ("Is in meadow ?", "Go to forest"), + ("Is in meadow ?", "Reach meadow>Is in other zones ?"), + ("Reach meadow>Is in other zone ?", "Reach meadow>Reach other zone"), + ("Reach meadow>Is in other zone ?", "Reach meadow>Go to forest"), + ] + ) check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))