Skip to content

Commit

Permalink
♻️ Identify CallGraph as its own object
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathïs Fédérico committed Jan 29, 2024
1 parent ddb03c6 commit e6333d3
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 113 deletions.
120 changes: 120 additions & 0 deletions src/hebg/call_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from enum import Enum
from re import S
from typing import Dict, List, Optional, Tuple, Union
from matplotlib.axes import Axes

from networkx import (
DiGraph,
draw_networkx_edges,
draw_networkx_labels,
draw_networkx_nodes,
)
import numpy as np

from hebg.node import Node


class CallEdgeStatus(Enum):
UNEXPLORED = "unexplored"
CALLED = "called"
FAILURE = "failure"


class CallGraph(DiGraph):
def __init__(self, initial_node: Node, **attr):
super().__init__(incoming_graph_data=None, **attr)
self.graph["frontiere"] = []
self.add_node(initial_node.name, order=0)

def extend_frontiere(self, nodes: List[Node], parent: Node):
frontiere: List[Node] = self.graph["frontiere"]
frontiere.extend(nodes)

for node in nodes:
self.add_edge(
parent.name, node.name, status=CallEdgeStatus.UNEXPLORED.value
)
node_data = self.nodes[node.name]
parent_data = self.nodes[parent.name]
if "order" not in node_data:
node_data["order"] = parent_data["order"] + 1

def pop_from_frontiere(self, parent: Node) -> Optional[Node]:
frontiere: List[Node] = self.graph["frontiere"]

next_node = None

while next_node is None:
if not frontiere:
return None
_next_node = frontiere.pop(np.argmin([node.cost for node in frontiere]))

if len(list(self.successors(_next_node))) > 0:
self.update_edge_status(parent, _next_node, CallEdgeStatus.FAILURE)
continue

self.update_edge_status(parent, _next_node, CallEdgeStatus.CALLED)
next_node = _next_node

return next_node

def update_edge_status(
self, start: Node, end: Node, status: Union[CallEdgeStatus, str]
):
status = CallEdgeStatus(status)
self.edges[start.name, end.name]["status"] = status.value

def draw(
self,
ax: Optional[Axes] = None,
pos: Optional[Dict[str, Tuple[float, float]]] = None,
nodes_kwargs: Optional[dict] = None,
label_kwargs: Optional[dict] = None,
edges_kwargs: Optional[dict] = None,
):
if pos is None:
pos = call_graph_pos(self)
if nodes_kwargs is None:
nodes_kwargs = {}
draw_networkx_nodes(self, ax=ax, pos=pos, **nodes_kwargs)
if label_kwargs is None:
label_kwargs = {}
draw_networkx_labels(self, ax=ax, pos=pos, **nodes_kwargs)
if edges_kwargs is None:
edges_kwargs = {}
if "connectionstyle" not in edges_kwargs:
edges_kwargs.update(connectionstyle="arc3,rad=-0.15")
draw_networkx_edges(
self,
ax=ax,
pos=pos,
edge_color=[
call_status_to_color(status)
for _, _, status in self.edges(data="status")
],
**edges_kwargs,
)


def call_status_to_color(status: Union[str, CallEdgeStatus]):
status = CallEdgeStatus(status)
if status is CallEdgeStatus.UNEXPLORED:
return "black"
if status is CallEdgeStatus.CALLED:
return "green"
if status is CallEdgeStatus.FAILURE:
return "red"
raise NotImplementedError


def call_graph_pos(call_graph: DiGraph) -> Dict[str, Tuple[float, float]]:
pos = {}
amount_by_order = {}
for node, node_data in call_graph.nodes(data=True):
order: int = node_data["order"]
if order not in amount_by_order:
amount_by_order[order] = 0
else:
amount_by_order[order] += 1
pos[node] = [order, amount_by_order[order] / 2]
return pos
77 changes: 18 additions & 59 deletions src/hebg/heb_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
""" Module containing the HEBGraph base class. """

from __future__ import annotations
from enum import Enum

from typing import Any, Dict, List, Optional, Tuple, TypeVar

import numpy as np
from matplotlib.axes import Axes
from networkx import DiGraph

from hebg.behavior import Behavior
from hebg.call_graph import CallEdgeStatus, CallGraph
from hebg.codegen import get_hebg_source
from hebg.draw import draw_hebgraph
from hebg.graph import get_roots, get_successors_with_index
Expand Down Expand Up @@ -73,7 +72,7 @@ def __init__(
self.all_behaviors = all_behaviors if all_behaviors is not None else {}

self._unrolled_graph = None
self.call_graph: Optional[DiGraph] = None
self.call_graph: Optional[CallGraph] = None

super().__init__(incoming_graph_data=incoming_graph_data, **attr)

Expand Down Expand Up @@ -122,30 +121,19 @@ def unrolled_graph(self) -> HEBGraph:
def __call__(
self,
observation,
call_graph: Optional[DiGraph] = None,
call_graph: Optional[CallGraph] = None,
) -> Any:
if call_graph is None:
call_graph = DiGraph()
call_graph.graph["frontiere"] = []
call_graph.add_node(self.behavior.name, order=0)
call_graph = CallGraph(initial_node=self.behavior)

self.call_graph = call_graph
return self._split_call_between_nodes(
self.roots, observation, call_graph=call_graph
)

def _get_action(
self,
node: Node,
observation: Any,
call_graph: DiGraph,
parent_name: str,
):
def _get_action(self, node: Node, observation: Any, call_graph: DiGraph):
# Behavior
if node.type == "behavior":
# To avoid cycling definitions
if len(list(call_graph.successors(node.name))) > 0:
return "Impossible"

# Search for name reference in all_behaviors
if node.name in self.all_behaviors:
node = self.all_behaviors[node.name]
Expand All @@ -161,56 +149,33 @@ def _get_action(
next_edge_index = int(node(observation))
next_nodes = get_successors_with_index(self, node, next_edge_index)
return self._split_call_between_nodes(
next_nodes, observation, call_graph=call_graph, parent_name=node.name
next_nodes, observation, call_graph=call_graph, parent=node
)
# Empty
if node.type == "empty":
return self._split_call_between_nodes(
list(self.successors(node)),
observation,
call_graph=call_graph,
parent_name=node.name,
parent=node,
)
raise ValueError(f"Unknowed value {node.type} for node.type with node: {node}.")

def _split_call_between_nodes(
self,
nodes: List[Node],
observation,
call_graph: DiGraph,
parent_name: Optional[Node] = None,
call_graph: CallGraph,
parent: Optional[Node] = None,
) -> List[Action]:
if parent_name is None:
parent_name = self.behavior.name

frontiere: List[Node] = call_graph.graph["frontiere"]
frontiere.extend(nodes)

for node in nodes:
call_graph.add_edge(
parent_name, node.name, status=CallEdgeStatus.UNEXPLORED.value
)
node_data = call_graph.nodes[node.name]
parent_data = call_graph.nodes[parent_name]
if "order" not in node_data:
node_data["order"] = parent_data["order"] + 1

action = "Impossible"
while action == "Impossible" and len(frontiere) > 0:
lesser_complex_node = frontiere.pop(
np.argmin([node.cost for node in frontiere])
)

action = self._get_action(
lesser_complex_node, observation, call_graph, parent_name=parent_name
)

call_graph.edges[parent_name, lesser_complex_node.name]["status"] = (
CallEdgeStatus.FAILURE.value
if action == "Impossible"
else CallEdgeStatus.CALLED.value
)

if parent is None:
parent = self.behavior

call_graph.extend_frontiere(nodes, parent)
next_node = call_graph.pop_from_frontiere(parent)
if next_node is None:
raise ValueError("No valid frontiere left in call_graph")
action = self._get_action(next_node, observation, call_graph)
return action

@property
Expand Down Expand Up @@ -240,12 +205,6 @@ def draw(
return draw_hebgraph(self, ax, **kwargs)


class CallEdgeStatus(Enum):
UNEXPLORED = "unexplored"
CALLED = "called"
FAILURE = "failure"


def remove_duplicate_actions(actions: List[Action]) -> List[Action]:
seen = set()
seen_add = seen.add
Expand Down
18 changes: 12 additions & 6 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@

""" Tests for the heb_graph package. """

from typing import TYPE_CHECKING
from typing import Protocol
from matplotlib import pyplot as plt
import networkx as nx

if TYPE_CHECKING:
from hebg.heb_graph import HEBGraph

class Graph(Protocol):
def draw(self, ax, pos):
"""Draw the graph on a matplotlib axes."""

def plot_graph(graph: "HEBGraph"):
def nodes(self) -> list:
"""Return a list of nodes"""


def plot_graph(graph: Graph, **kwargs):
_, ax = plt.subplots()
pos = None
if len(graph.roots) == 0:
if len(list(graph.nodes())) == 0:
pos = nx.spring_layout(graph)
graph.draw(ax, pos=pos)
graph.draw(ax, pos=pos, **kwargs)
plt.axis("off") # turn off axis
plt.show()
54 changes: 6 additions & 48 deletions tests/test_call_graph.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from typing import Union
from networkx import (
DiGraph,
draw_networkx_edges,
draw_networkx_labels,
draw_networkx_nodes,
)
from networkx import DiGraph

from hebg.behavior import Behavior
from hebg.heb_graph import CallEdgeStatus, HEBGraph
from hebg.heb_graph import HEBGraph
from hebg.node import Action

from pytest_mock import MockerFixture
Expand Down Expand Up @@ -137,6 +132,9 @@ def test_looping_goback(self):

call_graph = get_axe.graph.call_graph

if draw:
plot_graph(call_graph)

expected_order = [
"Get new axe",
"Has wood ?",
Expand All @@ -151,46 +149,6 @@ def test_looping_goback(self):
)
assert [node for node, _order in nodes_by_order] == expected_order

if draw:
import matplotlib.pyplot as plt

def status_to_color(status: Union[str, CallEdgeStatus]):
status = CallEdgeStatus(status)
if status is CallEdgeStatus.UNEXPLORED:
return "black"
if status is CallEdgeStatus.CALLED:
return "green"
if status is CallEdgeStatus.FAILURE:
return "red"
raise NotImplementedError

def call_graph_pos(call_graph: DiGraph):
pos = {}
amount_by_order = {}
for node, node_data in call_graph.nodes(data=True):
order: int = node_data.get("order")
if order not in amount_by_order:
amount_by_order[order] = 0
else:
amount_by_order[order] += 1
pos[node] = [order, amount_by_order[order] / 2]
return pos

pos = call_graph_pos(call_graph)
draw_networkx_nodes(call_graph, pos=pos)
draw_networkx_labels(call_graph, pos=pos)
draw_networkx_edges(
call_graph,
pos,
edge_color=[
status_to_color(status)
for _, _, status in call_graph.edges(data="status")
],
connectionstyle="arc3,rad=-0.15",
)
plt.axis("off") # turn off axis
plt.show()

expected_graph = DiGraph(
[
("Get new axe", "Has wood ?"),
Expand Down

0 comments on commit e6333d3

Please sign in to comment.