Issue #150 CrossBackendSplitter: decouple graph splitting from SubJob yielding

Introduce ProcessGraphSplitterInterface, with a first LoadCollectionGraphSplitter implementation based on the existing simple graph splitting logic.
soxofaan committed Sep 16, 2024
1 parent aba411c commit 5302f74
Showing 1 changed file with 85 additions and 33 deletions.
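
For context, a minimal usage sketch of the splitter API this commit introduces (editorial illustration, not part of the commit): the process graph, collection ids ("S2", "S1") and backend ids ("backend-a", "backend-b") are invented, and it assumes the classes land in the module as shown in the diff below.

# Hypothetical usage sketch; all ids below are invented for illustration.
from openeo_aggregator.partitionedjobs.crossbackend import LoadCollectionGraphSplitter

process_graph = {
    "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
    "lc2": {"process_id": "load_collection", "arguments": {"id": "S1"}},
    "merge": {
        "process_id": "merge_cubes",
        "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
        "result": True,
    },
}

splitter = LoadCollectionGraphSplitter(
    backend_for_collection=lambda cid: {"S2": "backend-a", "S1": "backend-b"}[cid],
)
result = splitter.split(process_graph)
# Expected: primary backend "backend-a" keeps {"lc1", "merge"},
# while "lc2" is split off as a single-node secondary subgraph on "backend-b".

Since `Counter.most_common` breaks ties by insertion order, "backend-a" (first collection seen) becomes the primary backend here.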
src/openeo_aggregator/partitionedjobs/crossbackend.py (118 changes: 85 additions, 33 deletions)
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import abc
 import collections
 import copy
 import dataclasses
@@ -18,6 +19,7 @@
     Iterator,
     List,
     Mapping,
+    NamedTuple,
     Optional,
     Protocol,
     Sequence,
@@ -45,6 +47,7 @@
 _LOAD_RESULT_PLACEHOLDER = "_placeholder:"

 # Some type annotation aliases to make things more self-documenting
+CollectionId = str
 SubGraphId = str
 NodeId = str
 BackendId = str
@@ -87,6 +90,75 @@ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId)
     }


+class _SubGraphData(NamedTuple):
+    split_node: NodeId
+    node_ids: Set[NodeId]
+    backend_id: BackendId
+
+
+class _PGSplitResult(NamedTuple):
+    primary_node_ids: Set[NodeId]
+    primary_backend_id: BackendId
+    secondary_graphs: List[_SubGraphData]
+
+
+class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def split(self, process_graph: FlatPG) -> _PGSplitResult:
+        """
+        Split given process graph (flat graph representation) into sub-graphs.
+        Returns primary graph data (node ids and backend id)
+        and secondary graph data (list of tuples: split node id, subgraph node ids, backend id).
+        """
+        ...
+
+
+class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface):
+    """Simple process graph splitter that just splits off load_collection nodes."""
+
+    def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
+        # TODO: also support not having a backend_for_collection map?
+        self._backend_for_collection = backend_for_collection
+        self._always_split = always_split
+
+    def split(self, process_graph: FlatPG) -> _PGSplitResult:
+        # Extract necessary back-ends from `load_collection` usage
+        backend_per_collection: Dict[str, str] = {
+            cid: self._backend_for_collection(cid)
+            for cid in (
+                node["arguments"]["id"] for node in process_graph.values() if node["process_id"] == "load_collection"
+            )
+        }
+        backend_usage = collections.Counter(backend_per_collection.values())
+        _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}")
+
+        # TODO: more options to determine primary backend?
+        primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None
+        secondary_backends = {b for b in backend_usage if b != primary_backend}
+        _log.info(f"Backend split: {primary_backend=} {secondary_backends=}")
+
+        primary_has_load_collection = False
+        primary_graph_node_ids = set()
+        secondary_graphs: List[_SubGraphData] = []
+        for node_id, node in process_graph.items():
+            if node["process_id"] == "load_collection":
+                bid = backend_per_collection[node["arguments"]["id"]]
+                if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
+                    primary_graph_node_ids.add(node_id)
+                    primary_has_load_collection = True
+                else:
+                    secondary_graphs.append(_SubGraphData(split_node=node_id, node_ids={node_id}, backend_id=bid))
+            else:
+                primary_graph_node_ids.add(node_id)
+
+        return _PGSplitResult(
+            primary_node_ids=primary_graph_node_ids,
+            primary_backend_id=primary_backend,
+            secondary_graphs=secondary_graphs,
+        )
+
+
 class CrossBackendSplitter(AbstractJobSplitter):
     """
     Split a process graph, to be executed across multiple back-ends,
@@ -97,14 +169,15 @@ class CrossBackendSplitter(AbstractJobSplitter):
     """

-    def __init__(self, backend_for_collection: Callable[[str], str], always_split: bool = False):
+    def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
        """
        :param backend_for_collection: callable that determines backend id for given collection id
        :param always_split: split all load_collections, also when on same backend
        """
-        # TODO: just handle this `backend_for_collection` callback with a regular method?
-        self.backend_for_collection = backend_for_collection
-        self._always_split = always_split
+        # TODO: inject splitter instead of building it here
+        self._graph_splitter = LoadCollectionGraphSplitter(
+            backend_for_collection=backend_for_collection, always_split=always_split
+        )

     def split_streaming(
         self,
@@ -127,36 +200,12 @@ def split_streaming(
         - dependencies as list of subgraph ids
         """

-        # Extract necessary back-ends from `load_collection` usage
-        backend_per_collection: Dict[str, str] = {
-            cid: self.backend_for_collection(cid)
-            for cid in (
-                node["arguments"]["id"] for node in process_graph.values() if node["process_id"] == "load_collection"
-            )
-        }
-        backend_usage = collections.Counter(backend_per_collection.values())
-        _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}")
-
-        # TODO: more options to determine primary backend?
-        primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None
-        secondary_backends = {b for b in backend_usage if b != primary_backend}
-        _log.info(f"Backend split: {primary_backend=} {secondary_backends=}")
+        graph_split_result = self._graph_splitter.split(process_graph=process_graph)

-        primary_has_load_collection = False
-        sub_graphs: List[Tuple[NodeId, Set[NodeId], BackendId]] = []
-        for node_id, node in process_graph.items():
-            if node["process_id"] == "load_collection":
-                bid = backend_per_collection[node["arguments"]["id"]]
-                if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
-                    primary_has_load_collection = True
-                else:
-                    sub_graphs.append((node_id, {node_id}, bid))
-
-        primary_graph_node_ids = set(process_graph.keys()).difference(n for _, ns, _ in sub_graphs for n in ns)
-        primary_pg = {k: process_graph[k] for k in primary_graph_node_ids}
+        primary_pg = {k: process_graph[k] for k in graph_split_result.primary_node_ids}
         primary_dependencies = []

-        for node_id, subgraph_node_ids, backend_id in sub_graphs:
+        for node_id, subgraph_node_ids, backend_id in graph_split_result.secondary_graphs:
             # New secondary pg
             sub_id = f"{backend_id}:{node_id}"
             sub_pg = {k: v for k, v in process_graph.items() if k in subgraph_node_ids}
@@ -178,8 +227,11 @@ def split_streaming(
             primary_pg.update(get_replacement(node_id=node_id, node=process_graph[node_id], subgraph_id=sub_id))
             primary_dependencies.append(sub_id)

-        primary_id = main_subgraph_id
-        yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies)
+        yield (
+            main_subgraph_id,
+            SubJob(process_graph=primary_pg, backend_id=graph_split_result.primary_backend_id),
+            primary_dependencies,
+        )

     def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
         """Split given process graph into a `PartitionedJob`"""
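A hypothetical sketch of driving the refactored split_streaming, which now delegates the splitting decision to the internally built LoadCollectionGraphSplitter (same invented process graph and backend mapping as the sketch above):

splitter = CrossBackendSplitter(
    backend_for_collection=lambda cid: {"S2": "backend-a", "S1": "backend-b"}[cid],
)
for subgraph_id, subjob, dependencies in splitter.split_streaming(process_graph):
    # Secondary subjobs are yielded first; the primary ("main") subjob comes last,
    # with its dependencies referencing the earlier subgraph ids.
    print(subgraph_id, subjob.backend_id, dependencies)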

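ProcessGraphSplitterInterface is the extension point this commit carves out: per the TODO in __init__, a later change could inject a different splitter into CrossBackendSplitter. A hypothetical alternative implementation (illustration only, not part of this commit), splitting by an explicit node-to-backend assignment:

from typing import Dict, List, Set

from openeo_aggregator.partitionedjobs.crossbackend import (
    ProcessGraphSplitterInterface,
    _PGSplitResult,
    _SubGraphData,
)


class FixedAssignmentSplitter(ProcessGraphSplitterInterface):
    """Hypothetical splitter driven by an explicit {node_id: backend_id} assignment."""

    def __init__(self, assignment: Dict[str, str], primary_backend: str):
        self._assignment = assignment
        self._primary_backend = primary_backend

    def split(self, process_graph: dict) -> _PGSplitResult:
        primary_node_ids: Set[str] = set()
        secondary_graphs: List[_SubGraphData] = []
        for node_id in process_graph:
            backend_id = self._assignment.get(node_id, self._primary_backend)
            if backend_id == self._primary_backend:
                primary_node_ids.add(node_id)
            else:
                # Each foreign node becomes its own single-node secondary subgraph.
                secondary_graphs.append(
                    _SubGraphData(split_node=node_id, node_ids={node_id}, backend_id=backend_id)
                )
        return _PGSplitResult(
            primary_node_ids=primary_node_ids,
            primary_backend_id=self._primary_backend,
            secondary_graphs=secondary_graphs,
        )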