Issue #150/#155 extract split_at_multiple logic from DeepGraphSplitter
soxofaan committed Sep 20, 2024
1 parent 835b8be commit de700f0
Showing 2 changed files with 111 additions and 44 deletions.
90 changes: 55 additions & 35 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -96,7 +96,7 @@ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId)
}


-class _SubGraphData(NamedTuple):
+class _PGSplitSubGraph(NamedTuple):
"""Container for result of ProcessGraphSplitterInterface.split"""

split_node: NodeId
@@ -109,7 +109,7 @@ class _PGSplitResult(NamedTuple):

primary_node_ids: Set[NodeId]
primary_backend_id: BackendId
-secondary_graphs: List[_SubGraphData]
+secondary_graphs: List[_PGSplitSubGraph]


class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta):
@@ -161,15 +161,15 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult:

primary_has_load_collection = False
primary_graph_node_ids = set()
-secondary_graphs: List[_SubGraphData] = []
+secondary_graphs: List[_PGSplitSubGraph] = []
for node_id, node in process_graph.items():
if node["process_id"] == "load_collection":
bid = backend_per_collection[node["arguments"]["id"]]
if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
primary_graph_node_ids.add(node_id)
primary_has_load_collection = True
else:
-secondary_graphs.append(_SubGraphData(split_node=node_id, node_ids={node_id}, backend_id=bid))
+secondary_graphs.append(_PGSplitSubGraph(split_node=node_id, node_ids={node_id}, backend_id=bid))
else:
primary_graph_node_ids.add(node_id)

@@ -579,7 +579,7 @@ def from_edges(

def node(self, node_id: NodeId) -> _GVNode:
if node_id not in self._graph:
raise GraphSplitException(f"Invalid node id {node_id}.")
raise GraphSplitException(f"Invalid node id {node_id!r}.")
return self._graph[node_id]

def iter_nodes(self) -> Iterator[Tuple[NodeId, _GVNode]]:
@@ -712,7 +712,7 @@ def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]:
def split_at(self, split_node_id: NodeId) -> Tuple[_GraphViewer, _GraphViewer]:
"""
Split graph at given node id (must be articulation point),
-creating two new graphs, containing original nodes and adaptation of the split node.
+creating two new graph viewers, containing original nodes and adaptation of the split node.
:return: two _GraphViewer objects: the upstream subgraph and the downstream subgraph
"""
@@ -729,7 +729,7 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
up_node_ids = set(self._walk(seeds=[split_node_id], next_nodes=next_nodes))

if split_node.flows_to.intersection(up_node_ids):
raise GraphSplitException(f"Graph can not be split at {split_node_id}: not an articulation point.")
raise GraphSplitException(f"Graph can not be split at {split_node_id!r}: not an articulation point.")

up_graph = {n: self.node(n) for n in up_node_ids}
# Replacement of original split node: no `flows_to` links
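For illustration, a minimal sketch of this split_at contract on a simple chain, consistent with the split_at_multiple tests added below (node names are illustrative):

    # Splitting the chain a -> b -> c at articulation point "b":
    graph = _GraphViewer.from_edges([("a", "b"), ("b", "c")])
    up, down = graph.split_at("b")
    # Upstream part: a -> b, where split node "b" has lost its flows_to links.
    assert sorted(up.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))]
    # Downstream part: b -> c, where "b" reappears as a dependency-free source.
    assert sorted(down.iter_nodes()) == [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))]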
@@ -810,6 +810,26 @@ def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]:
# All nodes can be handled as is, no need to split
yield []

+    def split_at_multiple(self, split_nodes: List[NodeId]) -> Dict[Union[NodeId, None], _GraphViewer]:
+        """
+        Split the graph viewer at multiple nodes, in the given order.
+        Each split produces an upstream and a downstream graph;
+        the downstream graph is used as input for the next split,
+        so the split nodes should be ordered from upstream to downstream.
+
+        :return: dictionary with split node ids as keys
+            (or None for the final downstream graph)
+            and the corresponding subgraph viewers as values.
+        """
+        result = {}
+        graph_to_split = self
+        for split_node_id in split_nodes:
+            up, down = graph_to_split.split_at(split_node_id=split_node_id)
+            result[split_node_id] = up
+            graph_to_split = down
+        result[None] = graph_to_split
+        return result
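A small usage sketch of the new split_at_multiple, consistent with the tests added below, showing the shape of the returned dictionary:

    # Two ordered splits on the chain a -> b -> c -> d:
    # each split node keys its upstream part, None keys the final remainder.
    graph = _GraphViewer.from_edges([("a", "b"), ("b", "c"), ("c", "d")])
    views = graph.split_at_multiple(["b", "c"])
    assert list(views.keys()) == ["b", "c", None]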


class DeepGraphSplitter(ProcessGraphSplitterInterface):
"""
@@ -820,6 +840,16 @@ def __init__(self, supporting_backends: SupportingBackendsMapper, primary_backend
self._supporting_backends_mapper = supporting_backends
self._primary_backend = primary_backend

+    def _pick_backend(self, backend_candidates: Union[frozenset[BackendId], None]) -> BackendId:
+        if backend_candidates is None:
+            if self._primary_backend:
+                return self._primary_backend
+            else:
+                raise GraphSplitException("DeepGraphSplitter._pick_backend: No backend candidates.")
+        else:
+            # TODO: better backend selection mechanism
+            return sorted(backend_candidates)[0]
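A minimal sketch of the _pick_backend fallback rule, assuming the constructor shown above and reusing the supporting_backends_from_node_id_dict helper from the test suite:

    # Without candidates: fall back to the configured primary backend.
    # With candidates: pick the alphabetically first one (placeholder rule, see TODO).
    splitter = DeepGraphSplitter(
        supporting_backends=supporting_backends_from_node_id_dict({}),
        primary_backend="b1",
    )
    assert splitter._pick_backend(backend_candidates=None) == "b1"
    assert splitter._pick_backend(backend_candidates=frozenset(["b2", "b3"])) == "b2"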

def split(self, process_graph: FlatPG) -> _PGSplitResult:
graph = _GraphViewer.from_flat_graph(
flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper
@@ -828,36 +858,26 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult:
for split_nodes in graph.produce_split_locations():
_log.debug(f"DeepGraphSplitter.split: evaluating split nodes: {split_nodes=}")

-            secondary_graphs: List[_SubGraphData] = []
-            graph_to_split = graph
-            for split_node_id in split_nodes:
-                up, down = graph_to_split.split_at(split_node_id=split_node_id)
-                # Use upstream graph as secondary graph
-                node_ids = set(nid for nid, _ in up.iter_nodes())
-                backend_candidates = up.get_backend_candidates_for_node_set(node_ids)
-                # TODO: better backend selection?
-                # TODO handle case where backend_candidates is None?
-                backend_id = sorted(backend_candidates)[0]
-                _log.debug(
-                    f"DeepGraphSplitter.split: secondary graph: from {split_node_id=}: {backend_id=} {node_ids=}"
-                )
-                secondary_graphs.append(
-                    _SubGraphData(
-                        split_node=split_node_id,
-                        node_ids=node_ids,
-                        backend_id=backend_id,
-                    )
-                )
+            split_views = graph.split_at_multiple(split_nodes=split_nodes)

-                # Prepare for next split (if any)
-                graph_to_split = down
+            # Extract nodes and backend ids for each subgraph
+            subgraph_node_ids = {k: set(n for n, _ in v.iter_nodes()) for k, v in split_views.items()}
+            subgraph_backend_ids = {
+                k: self._pick_backend(backend_candidates=v.get_backend_candidates_for_node_set(subgraph_node_ids[k]))
+                for k, v in split_views.items()
+            }
+            _log.debug(f"DeepGraphSplitter.split: {subgraph_node_ids=} {subgraph_backend_ids=}")

+            # Handle primary graph
+            split_views.pop(None)
+            primary_node_ids = subgraph_node_ids.pop(None)
+            primary_backend_id = subgraph_backend_ids.pop(None)

-            # Remaining graph is primary graph
-            primary_graph = graph_to_split
-            primary_node_ids = set(n for n, _ in primary_graph.iter_nodes())
-            backend_candidates = primary_graph.get_backend_candidates_for_node_set(primary_node_ids)
-            primary_backend_id = sorted(backend_candidates)[0]
            _log.debug(f"DeepGraphSplitter.split: primary graph: {primary_backend_id=} {primary_node_ids=}")
+            # Handle secondary graphs
+            secondary_graphs = [
+                _PGSplitSubGraph(split_node=k, node_ids=subgraph_node_ids[k], backend_id=subgraph_backend_ids[k])
+                for k in split_views.keys()
+            ]

if self._primary_backend is None or primary_backend_id == self._primary_backend:
_log.debug(f"DeepGraphSplitter.split: current split matches constraints")
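A rough end-to-end sketch of DeepGraphSplitter after this refactor; collection and backend ids are hypothetical, and the expected outcome mirrors test_split_with_primary_backend below:

    # Split a two-branch merge graph, forcing "b2" as primary backend.
    splitter = DeepGraphSplitter(
        supporting_backends=supporting_backends_from_node_id_dict({"lc1": "b1", "lc2": "b2"}),
        primary_backend="b2",
    )
    result = splitter.split(
        {
            "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
            "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
            "merge": {
                "process_id": "merge_cubes",
                "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
                "result": True,
            },
        }
    )
    assert result.primary_backend_id == "b2"
    assert result.secondary_graphs == [_PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1")]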
65 changes: 56 additions & 9 deletions tests/partitionedjobs/test_crossbackend.py
@@ -22,7 +22,7 @@
_GraphViewer,
_GVNode,
_PGSplitResult,
-    _SubGraphData,
+    _PGSplitSubGraph,
run_partitioned_job,
)

@@ -786,6 +786,13 @@ def test_split_at_basic(self):

def test_split_at_complex(self):
graph = _GraphViewer.from_edges(
+            #     a
+            #    / \
+            #   b   c       X
+            #    \ / \      |
+            #     d   e   f Y
+            #          \ /
+            #           g
[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e"), ("e", "g"), ("f", "g"), ("X", "Y")]
)
up, down = graph.split_at("e")
@@ -830,6 +837,44 @@ def test_split_at_non_articulation_point(self):
("c", _GVNode()),
]

+    def test_split_at_multiple_empty(self):
+        graph = _GraphViewer.from_edges([("a", "b")])
+        result = graph.split_at_multiple([])
+        assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == {
+            None: [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))],
+        }
+
+    def test_split_at_multiple_single(self):
+        graph = _GraphViewer.from_edges([("a", "b"), ("b", "c")])
+        result = graph.split_at_multiple(["b"])
+        assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == {
+            "b": [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))],
+            None: [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))],
+        }
+
+    def test_split_at_multiple_basic(self):
+        graph = _GraphViewer.from_edges(
+            [("a", "b"), ("b", "c"), ("c", "d")],
+            supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"}),
+        )
+        result = graph.split_at_multiple(["b", "c"])
+        assert {n: sorted(g.iter_nodes()) for (n, g) in result.items()} == {
+            "b": [("a", _GVNode(flows_to="b", backend_candidates="A")), ("b", _GVNode(depends_on="a"))],
+            "c": [("b", _GVNode(flows_to="c")), ("c", _GVNode(depends_on="b"))],
+            None: [("c", _GVNode(flows_to="d")), ("d", _GVNode(depends_on="c"))],
+        }
+
+    def test_split_at_multiple_invalid(self):
+        """Split nodes should be ordered from upstream to downstream."""
+        graph = _GraphViewer.from_edges(
+            [("a", "b"), ("b", "c"), ("c", "d")],
+        )
+        # Downstream order: works
+        _ = graph.split_at_multiple(["b", "c"])
+        # Upstream order: fails, because after the split at "c",
+        # node "b" is no longer part of the remaining downstream graph
+        with pytest.raises(GraphSplitException, match="Invalid node id 'b'"):
+            _ = graph.split_at_multiple(["c", "b"])

def test_produce_split_locations_simple(self):
"""Simple produce_split_locations use case: no need for splits"""
flat = {
Expand Down Expand Up @@ -956,7 +1001,7 @@ def test_simple_split(self):
primary_node_ids={"lc1", "lc2", "merge"},
primary_backend_id="b2",
secondary_graphs=[
-_SubGraphData(
+_PGSplitSubGraph(
split_node="lc1",
node_ids={"lc1"},
backend_id="b1",
@@ -995,7 +1040,7 @@ def test_simple_deep_split(self):
assert result == _PGSplitResult(
primary_node_ids={"lc2", "temporal2", "bands1", "merge"},
primary_backend_id="b2",
-secondary_graphs=[_SubGraphData(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")],
+secondary_graphs=[_PGSplitSubGraph(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")],
)

def test_shallow_triple_split(self):
@@ -1026,8 +1071,8 @@ def test_shallow_triple_split(self):
primary_node_ids={"lc1", "lc2", "lc3", "merge1", "merge2"},
primary_backend_id="b2",
secondary_graphs=[
_SubGraphData(split_node="lc1", node_ids={"lc1"}, backend_id="b1"),
_SubGraphData(split_node="lc3", node_ids={"lc3"}, backend_id="b3"),
_PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1"),
_PGSplitSubGraph(split_node="lc3", node_ids={"lc3"}, backend_id="b3"),
],
)

@@ -1067,16 +1112,18 @@ def test_triple_split(self):
primary_node_ids={"merge2", "merge1", "lc3", "spatial3"},
primary_backend_id="b3",
secondary_graphs=[
_SubGraphData(split_node="bands1", node_ids={"bands1", "lc1"}, backend_id="b1"),
_SubGraphData(split_node="merge1", node_ids={"bands1", "merge1", "temporal2", "lc2"}, backend_id="b2"),
_PGSplitSubGraph(split_node="bands1", node_ids={"bands1", "lc1"}, backend_id="b1"),
_PGSplitSubGraph(
split_node="merge1", node_ids={"bands1", "merge1", "temporal2", "lc2"}, backend_id="b2"
),
],
)

@pytest.mark.parametrize(
["primary_backend", "secondary_graph"],
[
("b1", _SubGraphData(split_node="lc2", node_ids={"lc2"}, backend_id="b2")),
("b2", _SubGraphData(split_node="lc1", node_ids={"lc1"}, backend_id="b1")),
("b1", _PGSplitSubGraph(split_node="lc2", node_ids={"lc2"}, backend_id="b2")),
("b2", _PGSplitSubGraph(split_node="lc1", node_ids={"lc1"}, backend_id="b1")),
],
)
def test_split_with_primary_backend(self, primary_backend, secondary_graph):
