Issue #150 basic implementation of DeepGraphSplitter
soxofaan committed Sep 17, 2024
1 parent 0977fd6 commit a1a9b64
Showing 5 changed files with 246 additions and 50 deletions.
4 changes: 2 additions & 2 deletions scripts/crossbackend-processing-poc.py
@@ -7,7 +7,7 @@
from openeo_aggregator.metadata import STAC_PROPERTY_FEDERATION_BACKENDS
from openeo_aggregator.partitionedjobs import PartitionedJob
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
CrossBackendJobSplitter,
LoadCollectionGraphSplitter,
run_partitioned_job,
)
@@ -63,7 +63,7 @@ def backend_for_collection(collection_id) -> str:
metadata = connection.describe_collection(collection_id)
return metadata["summaries"][STAC_PROPERTY_FEDERATION_BACKENDS][0]

splitter = CrossBackendSplitter(
splitter = CrossBackendJobSplitter(
graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=backend_for_collection, always_split=True)
)
pjob: PartitionedJob = splitter.split({"process_graph": process_graph})
4 changes: 2 additions & 2 deletions src/openeo_aggregator/backend.py
@@ -101,7 +101,7 @@
)
from openeo_aggregator.partitionedjobs import PartitionedJob
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
CrossBackendJobSplitter,
LoadCollectionGraphSplitter,
)
from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter
@@ -942,7 +942,7 @@ def _create_crossbackend_job(
def backend_for_collection(collection_id) -> str:
return self._catalog.get_backends_for_collection(cid=collection_id)[0]

splitter = CrossBackendSplitter(
splitter = CrossBackendJobSplitter(
graph_splitter=LoadCollectionGraphSplitter(
backend_for_collection=backend_for_collection,
# TODO: job option for `always_split` feature?
121 changes: 104 additions & 17 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -103,6 +103,13 @@ class _PGSplitResult(NamedTuple):


class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta):
"""
Interface for process graph splitters:
given a process graph (flat graph representation),
produce a main graph and secondary graphs (as subsets of node ids)
and the backends they are supposed to run on.
"""

@abc.abstractmethod
def split(self, process_graph: FlatPG) -> _PGSplitResult:
"""
@@ -115,7 +122,9 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult:


class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface):
"""Simple process graph splitter that just splits off load_collection nodes"""
"""
Simple process graph splitter that just splits off load_collection nodes.
"""

def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
# TODO: also support not having a backend_for_collection map?
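For illustration of the interface contract above: a minimal sketch of a custom implementation that never splits and pins everything to one fixed backend (class name is hypothetical, not part of this commit; `_PGSplitResult` fields as used elsewhere in this diff):

# Hypothetical sketch (not part of this commit): the simplest possible
# ProcessGraphSplitterInterface implementation, which never splits and
# assigns the whole process graph to a single, fixed backend.
class SingleBackendGraphSplitter(ProcessGraphSplitterInterface):
    def __init__(self, backend_id: BackendId):
        self._backend_id = backend_id

    def split(self, process_graph: FlatPG) -> _PGSplitResult:
        # The whole graph becomes the primary graph; no secondary graphs.
        return _PGSplitResult(
            primary_node_ids=set(process_graph.keys()),
            primary_backend_id=self._backend_id,
            secondary_graphs=[],
        )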
@@ -159,7 +168,7 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult:
)


class CrossBackendSplitter(AbstractJobSplitter):
class CrossBackendJobSplitter(AbstractJobSplitter):
"""
Split a process graph, to be executed across multiple back-ends,
based on availability of collections.
@@ -542,6 +551,8 @@ def from_edges(
return cls(graph=graph)

def node(self, node_id: NodeId) -> _FrozenNode:
if node_id not in self._graph:
raise GraphSplitException(f"Invalid node id {node_id}.")
return self._graph[node_id]

def iter_nodes(self) -> Iterator[Tuple[NodeId, _FrozenNode]]:
@@ -584,28 +595,35 @@ def walk_downstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True
"""
return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).flows_to, include_seeds=include_seeds)

def get_backend_candidates(self, node_id: NodeId) -> Union[frozenset[BackendId], None]:
def get_backend_candidates_for_node(self, node_id: NodeId) -> Union[frozenset[BackendId], None]:
"""Determine backend candidates for given node id"""
# TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated)
if self.node(node_id).backend_candidates is not None:
# Node has explicit backend candidates listed
return self.node(node_id).backend_candidates
elif self.node(node_id).depends_on:
# Backend support is unset: determine it (as intersection) from upstream nodes
# TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated)
upstream_candidates = (self.get_backend_candidates(n) for n in self.node(node_id).depends_on)
upstream_candidates = [c for c in upstream_candidates if c is not None]
if upstream_candidates:
return functools.reduce(lambda a, b: a.intersection(b), upstream_candidates)
else:
return None
return self.get_backend_candidates_for_node_set(self.node(node_id).depends_on)
else:
return None

def get_backend_candidates_for_node_set(self, node_ids: Iterable[NodeId]) -> Union[frozenset[BackendId], None]:
"""
Determine backend candidates for a set of nodes
"""
candidates = set(self.get_backend_candidates_for_node(n) for n in node_ids)
if candidates == {None}:
return None
candidates.discard(None)
return functools.reduce(lambda a, b: a.intersection(b), candidates)

def find_forsaken_nodes(self) -> Set[NodeId]:
"""
Find nodes that have no backend candidates to process them
"""
return set(node_id for (node_id, _) in self.iter_nodes() if self.get_backend_candidates(node_id) == set())
return set(
node_id for (node_id, _) in self.iter_nodes() if self.get_backend_candidates_for_node(node_id) == set()
)
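A worked toy example of the candidate propagation implemented above (collection and backend ids are made up; assumes the standard openEO `from_node` wiring that `from_flat_graph` parses):

# Hypothetical toy graph: two load_collection nodes merged into one cube.
graph = _FrozenGraph.from_flat_graph(
    flat_graph={
        "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
        "lc2": {"process_id": "load_collection", "arguments": {"id": "S1"}},
        "merge": {
            "process_id": "merge_cubes",
            "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
            "result": True,
        },
    },
    backend_candidates_map={"lc1": ["b1", "b2"], "lc2": ["b2", "b3"]},
)
# "merge" has no explicit candidates, so they are derived as the intersection
# of its upstream nodes: {"b1", "b2"} & {"b2", "b3"} -> {"b2"}
assert graph.get_backend_candidates_for_node("merge") == {"b2"}
# If the candidate sets were disjoint, the intersection would be empty
# and "merge" would be reported by find_forsaken_nodes().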

def find_articulation_points(self) -> Set[NodeId]:
"""
@@ -652,12 +670,11 @@ def split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]:
"""
Split graph at given node id (must be articulation point),
creating two new graphs, containing original nodes and adaptation of the split node.
:return: two _FrozenGraph objects: the upstream subgraph and the downstream subgraph
"""
split_node = self.node(split_node_id)

# TODO: first verify that node_id is a valid articulation point?
# Or let this fail, e.g. in validation of _FrozenGraph.__init__?

# Walk the graph, upstream from the split node
def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
node = self.node(node_id)
@@ -687,11 +704,22 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
)
down = _FrozenGraph(graph=down_graph)

return down, up
return up, down
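To illustrate the upstream/downstream return order fixed here, a hypothetical continuation of the toy graph sketched earlier:

# Hypothetical: split the toy graph from the earlier sketch at "lc1"
# (an articulation point).
up, down = graph.split_at("lc1")
# `up` is the upstream subgraph containing just "lc1";
# `down` keeps "lc1" as an adapted split node, alongside "lc2" and "merge".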

def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
"""
Produce disjoint subgraphs that can be processed independently
Produce disjoint subgraphs that can be processed independently.
:return: iterator of node listings.
Each node listing encodes a graph split (nodes ids where to split).
A node listing is ordered with the following in mind:
- the first node id does a first split in a downstream and upstream part.
The upstream part can be handled by a single backend.
The downstream part is not necessarily covered by a single backend,
in which case one or more additional splits will be necessary.
- the second node id does a second split of the downstream part of
the previous split.
- etc
"""
# Find nodes that have empty set of backend_candidates
forsaken_nodes = self.find_forsaken_nodes()
@@ -705,6 +733,8 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
# Collect nodes where we could split the graph in disjoint subgraphs
articulation_points: Set[NodeId] = set(self.find_articulation_points())

# TODO: allow/deny lists of what openEO processes can be split on? E.g. only split raster cube paths

# Walk upstream from forsaken nodes to find articulation points, where we can cut
split_options = [
n
@@ -717,12 +747,13 @@
# TODO: smarter picking of split node (e.g. one with most upstream nodes)
for split_node_id in split_options[:limit]:
# Split graph at this articulation point
down, up = self.split_at(split_node_id)
up, down = self.split_at(split_node_id)
if down.find_forsaken_nodes():
down_splits = list(down.produce_split_locations(limit=limit - 1))
else:
down_splits = [[]]
if up.find_forsaken_nodes():
# TODO: will this actually happen? the upstream sub-graph should be single-backend by design?
up_splits = list(up.produce_split_locations(limit=limit - 1))
else:
up_splits = [[]]
@@ -733,3 +764,59 @@
else:
# All nodes can be handled as is, no need to split
yield []
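A hedged end-to-end illustration of `produce_split_locations` (toy graph and backend ids are made up; the exact listings the implementation yields may differ):

# Hypothetical: "lc1" can only run on backend "b1" and "lc2" only on "b2",
# so "merge" has an empty candidate intersection (it is "forsaken")
# and at least one split is needed.
graph = _FrozenGraph.from_flat_graph(
    flat_graph={
        "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
        "lc2": {"process_id": "load_collection", "arguments": {"id": "S1"}},
        "merge": {
            "process_id": "merge_cubes",
            "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
            "result": True,
        },
    },
    backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]},
)
print(list(graph.produce_split_locations()))
# e.g. [["lc1"], ["lc2"]]: split off "lc1" (leaving "lc2" + "merge" on "b2"),
# or split off "lc2" (leaving "lc1" + "merge" on "b1").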


class DeepGraphSplitter(ProcessGraphSplitterInterface):
"""
More advanced graph splitting (compared to just splitting off `load_collection` nodes)
"""

# TODO: unify:
# - backend_for_collection: Callable[[CollectionId], BackendId]
# - backend_candidates_map: Dict[NodeId, Iterable[BackendId]]
# Note that the nodeid-backendid mapping smells like bad decoupling
# as the process graph is given to split methods, while mapping to __init__
# TODO: validation for Iterable[BackendId] (avoid passing a single string instead of iterable of strings)
def __init__(self, backend_candidates_map: Dict[NodeId, Iterable[BackendId]]):
self._backend_candidates_map = backend_candidates_map

def split(self, process_graph: FlatPG) -> _PGSplitResult:
graph = _FrozenGraph.from_flat_graph(
flat_graph=process_graph, backend_candidates_map=self._backend_candidates_map
)

# TODO: make picking "optimal" split location set a bit more deterministic (e.g. sort first)
(split_nodes,) = graph.produce_split_locations(limit=1)

secondary_graphs: List[_SubGraphData] = []
graph_to_split = graph
for split_node_id in split_nodes:
up, down = graph_to_split.split_at(split_node_id=split_node_id)
# Use upstream graph as secondary graph
node_ids = set(nid for nid, _ in up.iter_nodes())
backend_candidates = up.get_backend_candidates_for_node_set(node_ids)
# TODO: better backend selection?
# TODO handle case where backend_candidates is None?
backend_id = sorted(backend_candidates)[0]
secondary_graphs.append(
_SubGraphData(
split_node=split_node_id,
node_ids=node_ids,
backend_id=backend_id,
)
)

# Prepare for next split (if any)
graph_to_split = down

# Remaining graph is primary graph
primary_graph = graph_to_split
primary_node_ids = set(n for n, _ in primary_graph.iter_nodes())
backend_candidates = primary_graph.get_backend_candidates_for_node_set(primary_node_ids)
primary_backend_id = sorted(backend_candidates)[0]

return _PGSplitResult(
primary_node_ids=primary_node_ids,
primary_backend_id=primary_backend_id,
secondary_graphs=secondary_graphs,
)
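A hedged usage sketch of the new `DeepGraphSplitter` (same hypothetical toy graph as above; the concrete split the implementation picks may differ):

# Hypothetical usage sketch, reusing the toy graph from above.
splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
result = splitter.split(
    process_graph={
        "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
        "lc2": {"process_id": "load_collection", "arguments": {"id": "S1"}},
        "merge": {
            "process_id": "merge_cubes",
            "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
            "result": True,
        },
    }
)
# `result` is a _PGSplitResult, e.g. with "lc1" split off as a secondary
# graph on "b1", and the rest (plus the "lc1" split placeholder) as the
# primary graph on "b2", roughly:
#   _PGSplitResult(
#       primary_node_ids={"lc1", "lc2", "merge"},
#       primary_backend_id="b2",
#       secondary_graphs=[_SubGraphData(split_node="lc1", node_ids={"lc1"}, backend_id="b1")],
#   )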
4 changes: 2 additions & 2 deletions src/openeo_aggregator/partitionedjobs/tracking.py
@@ -24,7 +24,7 @@
SubJob,
)
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
CrossBackendJobSplitter,
SubGraphId,
)
from openeo_aggregator.partitionedjobs.splitting import TileGridSplitter
@@ -71,7 +71,7 @@ def create_crossbackend_pjob(
process: PGWithMetadata,
metadata: dict,
job_options: Optional[dict] = None,
splitter: CrossBackendSplitter,
splitter: CrossBackendJobSplitter,
) -> str:
"""
crossbackend partitioned job creation is different from original partitioned
(diff of the remaining changed file not shown)