Skip to content

Commit

Permalink
✅ Add loop tests
Browse files Browse the repository at this point in the history
🧪 Add loop failling without alternative
  • Loading branch information
MathisFederico committed Jan 7, 2024
1 parent 079f547 commit 650528b
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 5 deletions.
File renamed without changes.
69 changes: 69 additions & 0 deletions tests/examples/behaviors/loop_without_alternative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import List

from hebg import HEBGraph, Action, FeatureCondition, Behavior


class ReachForest(Behavior):
"""Reach forest"""

def __init__(self) -> None:
"""Reach forest"""
super().__init__("Reach forest")

def build_graph(self) -> HEBGraph:
graph = HEBGraph(self)
is_in_other_zone = FeatureCondition("Is in other zone ?")
graph.add_edge(is_in_other_zone, Behavior("Reach other zone"), index=False)
graph.add_edge(is_in_other_zone, Action("> forest"), index=True)
is_in_other_zone = FeatureCondition("Is in meadow ?")
graph.add_edge(is_in_other_zone, Behavior("Reach meadow"), index=False)
graph.add_edge(is_in_other_zone, Action("> forest"), index=True)
return graph


class ReachOtherZone(Behavior):
"""Reach other zone"""

def __init__(self) -> None:
"""Reach other zone"""
super().__init__("Reach other zone")

def build_graph(self) -> HEBGraph:
graph = HEBGraph(self)
is_in_forest = FeatureCondition("Is in forest ?")
graph.add_edge(is_in_forest, Behavior("Reach forest"), index=False)
graph.add_edge(is_in_forest, Action("> other zone"), index=True)
is_in_other_zone = FeatureCondition("Is in meadow ?")
graph.add_edge(is_in_other_zone, Behavior("Reach meadow"), index=False)
graph.add_edge(is_in_other_zone, Action("> other zone"), index=True)
return graph


class ReachMeadow(Behavior):
"""Reach meadow"""

def __init__(self) -> None:
"""Reach meadow"""
super().__init__("Reach meadow")

def build_graph(self) -> HEBGraph:
graph = HEBGraph(self)
is_in_forest = FeatureCondition("Is in forest ?")
graph.add_edge(is_in_forest, Behavior("Reach forest"), index=False)
graph.add_edge(is_in_forest, Action("> meadow"), index=True)
is_in_other_zone = FeatureCondition("Is in other zone ?")
graph.add_edge(is_in_other_zone, Behavior("Reach other zone"), index=False)
graph.add_edge(is_in_other_zone, Action("> meadow"), index=True)
return graph


def build_looping_behaviors() -> List[Behavior]:
behaviors: List[Behavior] = [
ReachForest(),
ReachOtherZone(),
ReachMeadow(),
]
all_behaviors = {behavior.name: behavior for behavior in behaviors}
for behavior in behaviors:
behavior.graph.all_behaviors = all_behaviors
return behaviors
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from hebg import HEBGraph
from hebg.unrolling import unroll_graph

from tests.examples.behaviors.loop import build_looping_behaviors
from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors

import matplotlib.pyplot as plt

Expand Down
45 changes: 45 additions & 0 deletions tests/integration/test_loop_without_alternative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
import pytest_check as check

import networkx as nx
from hebg import HEBGraph
from hebg.unrolling import unroll_graph

from tests.examples.behaviors.loop_without_alternative import build_looping_behaviors

import matplotlib.pyplot as plt


class TestLoop:
"""Tests for the loop example"""

@pytest.fixture(autouse=True)
def setup_method(self):
(
self.reach_forest,
self.reach_other_zone,
self.reach_meadow,
) = build_looping_behaviors()

@pytest.mark.xfail
def test_unroll_reach_forest(self):
draw = False
unrolled_graph = unroll_graph(
self.reach_forest.graph,
add_prefix=True,
cut_looping_alternatives=True,
)
if draw:
_plot_graph(unrolled_graph)

expected_graph = nx.DiGraph()
check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))


def _plot_graph(graph: "HEBGraph"):
_, ax = plt.subplots()
pos = None
if len(graph.roots) == 0:
pos = nx.spring_layout(graph)
graph.draw(ax, pos=pos)
plt.show()
7 changes: 3 additions & 4 deletions tests/integration/test_paper_basic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,9 @@ def test_learning_complexity(self):
f"{behavior}: {c_learning}|{expected_learning_complexities[behavior]}"
f" {saved_complexity}|{expected_saved_complexities[behavior]}"
)
diff_complexity = abs(c_learning - expected_learning_complexities[behavior])
diff_saved = abs(saved_complexity - expected_saved_complexities[behavior])
check.less(diff_complexity, 1e-14)
check.less(diff_saved, 1e-14)

check.almost_equal(c_learning, expected_learning_complexities[behavior])
check.almost_equal(saved_complexity, expected_saved_complexities[behavior])

def test_codegen(self):
expected_code = "\n".join(
Expand Down

0 comments on commit 650528b

Please sign in to comment.