diff --git a/scripts/crossbackend-processing-poc.py b/scripts/crossbackend-processing-poc.py
index e8a138c..01f6a33 100644
--- a/scripts/crossbackend-processing-poc.py
+++ b/scripts/crossbackend-processing-poc.py
@@ -7,7 +7,7 @@
 from openeo_aggregator.metadata import STAC_PROPERTY_FEDERATION_BACKENDS
 from openeo_aggregator.partitionedjobs import PartitionedJob
 from openeo_aggregator.partitionedjobs.crossbackend import (
-    CrossBackendSplitter,
+    CrossBackendJobSplitter,
     LoadCollectionGraphSplitter,
     run_partitioned_job,
 )
@@ -63,7 +63,7 @@ def backend_for_collection(collection_id) -> str:
         metadata = connection.describe_collection(collection_id)
         return metadata["summaries"][STAC_PROPERTY_FEDERATION_BACKENDS][0]

-    splitter = CrossBackendSplitter(
+    splitter = CrossBackendJobSplitter(
         graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=backend_for_collection, always_split=True)
     )
     pjob: PartitionedJob = splitter.split({"process_graph": process_graph})
diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py
index 541ab93..b4b8aa3 100644
--- a/src/openeo_aggregator/backend.py
+++ b/src/openeo_aggregator/backend.py
@@ -101,7 +101,7 @@
 )
 from openeo_aggregator.partitionedjobs import PartitionedJob
 from openeo_aggregator.partitionedjobs.crossbackend import (
-    CrossBackendSplitter,
+    CrossBackendJobSplitter,
     LoadCollectionGraphSplitter,
 )
 from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter
@@ -942,7 +942,7 @@ def _create_crossbackend_job(
         def backend_for_collection(collection_id) -> str:
             return self._catalog.get_backends_for_collection(cid=collection_id)[0]

-        splitter = CrossBackendSplitter(
+        splitter = CrossBackendJobSplitter(
             graph_splitter=LoadCollectionGraphSplitter(
                 backend_for_collection=backend_for_collection,
                 # TODO: job option for `always_split` feature?
diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py
index 3163e79..9af87e2 100644
--- a/src/openeo_aggregator/partitionedjobs/crossbackend.py
+++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -103,6 +103,13 @@ class _PGSplitResult(NamedTuple):


 class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta):
+    """
+    Interface for process graph splitters:
+    given a process graph (flat graph representation),
+    produce a main graph and secondary graphs (as subsets of node ids)
+    and the backends they are supposed to run on.
+    """
+
     @abc.abstractmethod
     def split(self, process_graph: FlatPG) -> _PGSplitResult:
         """
@@ -115,7 +122,9 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult:


 class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface):
-    """Simple process graph splitter that just splits off load_collection nodes"""
+    """
+    Simple process graph splitter that just splits off load_collection nodes.
+    """

     def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
         # TODO: also support not not having a backend_for_collection map?
@@ -159,7 +168,7 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult:
         )


-class CrossBackendSplitter(AbstractJobSplitter):
+class CrossBackendJobSplitter(AbstractJobSplitter):
     """
     Split a process graph, to be executed across multiple back-ends,
     based on availability of collections.
@@ -542,6 +551,8 @@ def from_edges(
         return cls(graph=graph)

     def node(self, node_id: NodeId) -> _FrozenNode:
+        if node_id not in self._graph:
+            raise GraphSplitException(f"Invalid node id {node_id}.")
         return self._graph[node_id]

     def iter_nodes(self) -> Iterator[Tuple[NodeId, _FrozenNode]]:
@@ -584,28 +595,35 @@ def walk_downstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = T
         """
         return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).flows_to, include_seeds=include_seeds)

-    def get_backend_candidates(self, node_id: NodeId) -> Union[frozenset[BackendId], None]:
+    def get_backend_candidates_for_node(self, node_id: NodeId) -> Union[frozenset[BackendId], None]:
         """Determine backend candidates for given node id"""
+        # TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated)
         if self.node(node_id).backend_candidates is not None:
             # Node has explicit backend candidates listed
             return self.node(node_id).backend_candidates
         elif self.node(node_id).depends_on:
             # Backend support is unset: determine it (as intersection) from upstream nodes
-            # TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated)
-            upstream_candidates = (self.get_backend_candidates(n) for n in self.node(node_id).depends_on)
-            upstream_candidates = [c for c in upstream_candidates if c is not None]
-            if upstream_candidates:
-                return functools.reduce(lambda a, b: a.intersection(b), upstream_candidates)
-            else:
-                return None
+            return self.get_backend_candidates_for_node_set(self.node(node_id).depends_on)
         else:
             return None

+    def get_backend_candidates_for_node_set(self, node_ids: Iterable[NodeId]) -> Union[frozenset[BackendId], None]:
+        """
+        Determine backend candidates for a set of nodes
+        """
+        candidates = set(self.get_backend_candidates_for_node(n) for n in node_ids)
+        if candidates == {None}:
+            return None
+        candidates.discard(None)
+        return functools.reduce(lambda a, b: a.intersection(b), candidates)
+
     def find_forsaken_nodes(self) -> Set[NodeId]:
         """
         Find nodes that have no backend candidates to process them
         """
-        return set(node_id for (node_id, _) in self.iter_nodes() if self.get_backend_candidates(node_id) == set())
+        return set(
+            node_id for (node_id, _) in self.iter_nodes() if self.get_backend_candidates_for_node(node_id) == set()
+        )

     def find_articulation_points(self) -> Set[NodeId]:
         """
@@ -652,12 +670,11 @@ def split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]:
         """
         Split graph at given node id (must be articulation point),
         creating two new graphs, containing original nodes and adaptation of the split node.
+
+        :return: two _FrozenGraph objects: the upstream subgraph and the downstream subgraph
         """
         split_node = self.node(split_node_id)
-        # TODO: first verify that node_id is a valid articulation point?
-        #       Or let this fail, e.g. in validation of _FrozenGraph.__init__?
-
         # Walk the graph, upstream from the split node
         def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
             node = self.node(node_id)
@@ -687,11 +704,22 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
         )
         down = _FrozenGraph(graph=down_graph)

-        return down, up
+        return up, down

     def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
         """
-        Produce disjoint subgraphs that can be processed independently
+        Produce disjoint subgraphs that can be processed independently.
+
+        :return: iterator of node listings.
+            Each node listing encodes a graph split (nodes ids where to split).
+            A node listing is ordered with the following in mind:
+            - the first node id does a first split in a downstream and upstream part.
+              The upstream part can be handled by a single backend.
+              The downstream part is not necessarily covered by a single backend,
+              in which case one or more additional splits will be necessary.
+            - the second node id does a second split of the downstream part of
+              the previous split.
+            - etc
         """
         # Find nodes that have empty set of backend_candidates
         forsaken_nodes = self.find_forsaken_nodes()
@@ -705,6 +733,8 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:

             # Collect nodes where we could split the graph in disjoint subgraphs
             articulation_points: Set[NodeId] = set(self.find_articulation_points())
+
+            # TODO: allow/deny lists of what openEO processes can be split on? E.g. only split raster cube paths
+
             # Walk upstream from forsaken nodes to find articulation points, where we can cut
             split_options = [
                 n
@@ -717,12 +747,13 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
             # TODO: smarter picking of split node (e.g. one with most upstream nodes)
             for split_node_id in split_options[:limit]:
                 # Split graph at this articulation point
-                down, up = self.split_at(split_node_id)
+                up, down = self.split_at(split_node_id)
                 if down.find_forsaken_nodes():
                     down_splits = list(down.produce_split_locations(limit=limit - 1))
                 else:
                     down_splits = [[]]
                 if up.find_forsaken_nodes():
+                    # TODO: will this actually happen? the upstream sub-graph should be single-backend by design?
                     up_splits = list(up.produce_split_locations(limit=limit - 1))
                 else:
                     up_splits = [[]]
@@ -733,3 +764,59 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
         else:
             # All nodes can be handled as is, no need to split
             yield []
+
+
+class DeepGraphSplitter(ProcessGraphSplitterInterface):
+    """
+    More advanced graph splitting (compared to just splitting off `load_collection` nodes)
+    """
+
+    # TODO: unify:
+    #       - backend_for_collection: Callable[[CollectionId], BackendId]
+    #       - backend_candidates_map: Dict[NodeId, Iterable[BackendId]]
+    #       Note that the nodeid-backendid mapping smells like bad decoupling
+    #       as the process graph is given to split methods, while mapping to __init__
+    # TODO: validation for Iterable[BackendId] (avoid passing a single string instead of iterable of strings)
+    def __init__(self, backend_candidates_map: Dict[NodeId, Iterable[BackendId]]):
+        self._backend_candidates_map = backend_candidates_map
+
+    def split(self, process_graph: FlatPG) -> _PGSplitResult:
+        graph = _FrozenGraph.from_flat_graph(
+            flat_graph=process_graph, backend_candidates_map=self._backend_candidates_map
+        )
+
+        # TODO: make picking "optimal" split location set a bit more deterministic (e.g. sort first)
+        (split_nodes,) = graph.produce_split_locations(limit=1)
+
+        secondary_graphs: List[_SubGraphData] = []
+        graph_to_split = graph
+        for split_node_id in split_nodes:
+            up, down = graph_to_split.split_at(split_node_id=split_node_id)
+            # Use upstream graph as secondary graph
+            node_ids = set(nid for nid, _ in up.iter_nodes())
+            backend_candidates = up.get_backend_candidates_for_node_set(node_ids)
+            # TODO: better backend selection?
+            # TODO handle case where backend_candidates is None?
+            backend_id = sorted(backend_candidates)[0]
+            secondary_graphs.append(
+                _SubGraphData(
+                    split_node=split_node_id,
+                    node_ids=node_ids,
+                    backend_id=backend_id,
+                )
+            )
+
+            # Prepare for next split (if any)
+            graph_to_split = down
+
+        # Remaining graph is primary graph
+        primary_graph = graph_to_split
+        primary_node_ids = set(n for n, _ in primary_graph.iter_nodes())
+        backend_candidates = primary_graph.get_backend_candidates_for_node_set(primary_node_ids)
+        primary_backend_id = sorted(backend_candidates)[0]
+
+        return _PGSplitResult(
+            primary_node_ids=primary_node_ids,
+            primary_backend_id=primary_backend_id,
+            secondary_graphs=secondary_graphs,
+        )
diff --git a/src/openeo_aggregator/partitionedjobs/tracking.py b/src/openeo_aggregator/partitionedjobs/tracking.py
index a34f36e..26ce77a 100644
--- a/src/openeo_aggregator/partitionedjobs/tracking.py
+++ b/src/openeo_aggregator/partitionedjobs/tracking.py
@@ -24,7 +24,7 @@
     SubJob,
 )
 from openeo_aggregator.partitionedjobs.crossbackend import (
-    CrossBackendSplitter,
+    CrossBackendJobSplitter,
     SubGraphId,
 )
 from openeo_aggregator.partitionedjobs.splitting import TileGridSplitter
@@ -71,7 +71,7 @@ def create_crossbackend_pjob(
         process: PGWithMetadata,
         metadata: dict,
         job_options: Optional[dict] = None,
-        splitter: CrossBackendSplitter,
+        splitter: CrossBackendJobSplitter,
     ) -> str:
         """
         crossbackend partitioned job creation is different from original partitioned
diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py
index 014ea4b..f1e774f 100644
--- a/tests/partitionedjobs/test_crossbackend.py
+++ b/tests/partitionedjobs/test_crossbackend.py
@@ -14,12 +14,15 @@
 from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
 from openeo_aggregator.partitionedjobs.crossbackend import (
-    CrossBackendSplitter,
+    CrossBackendJobSplitter,
+    DeepGraphSplitter,
     GraphSplitException,
     LoadCollectionGraphSplitter,
     SubGraphId,
     _FrozenGraph,
     _FrozenNode,
+    _PGSplitResult,
+    _SubGraphData,
     run_partitioned_job,
 )
@@ -27,7 +30,7 @@ class TestCrossBackendSplitter:
     def test_split_simple(self):
         process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
-        splitter = CrossBackendSplitter(
+        splitter = CrossBackendJobSplitter(
             graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo")
         )
         res = splitter.split({"process_graph": process_graph})
@@ -37,7 +40,7 @@ def test_split_streaming_simple(self):
         process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
-        splitter = CrossBackendSplitter(
+        splitter = CrossBackendJobSplitter(
             graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo")
         )
         res = splitter.split_streaming(process_graph)
@@ -61,7 +64,7 @@ def test_split_basic(self):
                 "result": True,
             },
         }
-        splitter = CrossBackendSplitter(
+        splitter = CrossBackendJobSplitter(
             graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
         )
         res = splitter.split({"process_graph": process_graph})
@@ -126,7 +129,7 @@ def test_split_streaming_basic(self):
                 "result": True,
             },
         }
-        splitter = CrossBackendSplitter(
+        splitter = CrossBackendJobSplitter(
             graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
         )
         result = splitter.split_streaming(process_graph)
@@ -188,7 +191,7 @@ def test_split_streaming_get_replacement(self):
                 "result": True,
             },
         }
-        splitter = CrossBackendSplitter(
+        splitter = CrossBackendJobSplitter(
             graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
         )
@@ -386,7 +389,7 @@ def test_basic(self, aggregator: _FakeAggregator):
                 "result": True,
             },
         }
-        splitter = CrossBackendSplitter(
+        splitter = CrossBackendJobSplitter(
             graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
         )
         pjob: PartitionedJob = splitter.split({"process_graph": process_graph})
@@ -512,32 +515,47 @@ def test_get_backend_candidates_basic(self):
             [("a", "b"), ("b", "d"), ("c", "d")],
             backend_candidates_map={"a": ["b1"], "c": ["b2"]},
         )
-        assert graph.get_backend_candidates("a") == {"b1"}
-        assert graph.get_backend_candidates("b") == {"b1"}
-        assert graph.get_backend_candidates("c") == {"b2"}
-        assert graph.get_backend_candidates("d") == set()
+        assert graph.get_backend_candidates_for_node("a") == {"b1"}
+        assert graph.get_backend_candidates_for_node("b") == {"b1"}
+        assert graph.get_backend_candidates_for_node("c") == {"b2"}
+        assert graph.get_backend_candidates_for_node("d") == set()
+
+        assert graph.get_backend_candidates_for_node_set(["a"]) == {"b1"}
+        assert graph.get_backend_candidates_for_node_set(["b"]) == {"b1"}
+        assert graph.get_backend_candidates_for_node_set(["c"]) == {"b2"}
+        assert graph.get_backend_candidates_for_node_set(["d"]) == set()
+        assert graph.get_backend_candidates_for_node_set(["a", "b"]) == {"b1"}
+        assert graph.get_backend_candidates_for_node_set(["a", "b", "c"]) == set()
+        assert graph.get_backend_candidates_for_node_set(["a", "b", "d"]) == set()

     def test_get_backend_candidates_none(self):
         graph = _FrozenGraph.from_edges(
             [("a", "b"), ("b", "d"), ("c", "d")],
             backend_candidates_map={},
         )
-        assert graph.get_backend_candidates("a") is None
-        assert graph.get_backend_candidates("b") is None
-        assert graph.get_backend_candidates("c") is None
-        assert graph.get_backend_candidates("d") is None
+        assert graph.get_backend_candidates_for_node("a") is None
+        assert graph.get_backend_candidates_for_node("b") is None
+        assert graph.get_backend_candidates_for_node("c") is None
+        assert graph.get_backend_candidates_for_node("d") is None
+
+        assert graph.get_backend_candidates_for_node_set(["a", "b"]) is None
+        assert graph.get_backend_candidates_for_node_set(["a", "b", "c"]) is None

     def test_get_backend_candidates_intersection(self):
         graph = _FrozenGraph.from_edges(
             [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f")],
             backend_candidates_map={"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]},
         )
-        assert graph.get_backend_candidates("a") == {"b1", "b2"}
-        assert graph.get_backend_candidates("b") == {"b2", "b3"}
-        assert graph.get_backend_candidates("c") == {"b4"}
-        assert graph.get_backend_candidates("d") == {"b2"}
-        assert graph.get_backend_candidates("e") == set()
-        assert graph.get_backend_candidates("f") == set()
+        assert graph.get_backend_candidates_for_node("a") == {"b1", "b2"}
+        assert graph.get_backend_candidates_for_node("b") == {"b2", "b3"}
+        assert graph.get_backend_candidates_for_node("c") == {"b4"}
+        assert graph.get_backend_candidates_for_node("d") == {"b2"}
+        assert graph.get_backend_candidates_for_node("e") == set()
+        assert graph.get_backend_candidates_for_node("f") == set()
+
+        assert graph.get_backend_candidates_for_node_set(["a", "b"]) == {"b2"}
+        assert graph.get_backend_candidates_for_node_set(["a", "b", "d"]) == {"b2"}
+        assert graph.get_backend_candidates_for_node_set(["c", "d"]) == set()

     def test_find_forsaken_nodes(self):
         graph = _FrozenGraph.from_edges(
@@ -632,7 +650,7 @@ def test_find_articulation_points(self, flat, expected):
     def test_split_at_minimal(self):
         graph = _FrozenGraph.from_edges([("a", "b")], backend_candidates_map={"a": "A"})
         # Split at a
-        down, up = graph.split_at("a")
+        up, down = graph.split_at("a")
         assert sorted(up.iter_nodes()) == [
             ("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=frozenset(["A"]))),
         ]
@@ -641,7 +659,7 @@ def test_split_at_minimal(self):
             ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
         ]
         # Split at b
-        down, up = graph.split_at("b")
+        up, down = graph.split_at("b")
         assert sorted(up.iter_nodes()) == [
             ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))),
             ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
@@ -652,7 +670,7 @@ def test_split_at_minimal(self):

     def test_split_at_basic(self):
         graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c")], backend_candidates_map={"a": "A"})
-        down, up = graph.split_at("b")
+        up, down = graph.split_at("b")
         assert sorted(up.iter_nodes()) == [
             ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))),
             ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
@@ -666,7 +684,7 @@ def test_split_at_complex(self):
         graph = _FrozenGraph.from_edges(
             [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e"), ("e", "g"), ("f", "g"), ("X", "Y")]
         )
-        down, up = graph.split_at("e")
+        up, down = graph.split_at("e")
         assert sorted(up.iter_nodes()) == sorted(
             _FrozenGraph.from_edges([("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e")]).iter_nodes()
         )
@@ -680,7 +698,7 @@ def test_split_at_non_articulation_point(self):
             _ = graph.split_at("b")

         # These should still work
-        down, up = graph.split_at("a")
+        up, down = graph.split_at("a")
         assert sorted(up.iter_nodes()) == [
             ("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
         ]
@@ -690,7 +708,7 @@ def test_split_at_non_articulation_point(self):
             ("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)),
         ]

-        down, up = graph.split_at("c")
+        up, down = graph.split_at("c")
         assert sorted(up.iter_nodes()) == [
             ("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)),
             ("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)),
@@ -763,3 +781,94 @@ def test_produce_split_locations_merge_longer_triangle(self):
             [["mask1"], ["bands2"], ["lc1"], ["lc2"]],
             [["bands2"], ["mask1"], ["lc2"], ["lc1"]],
         )
+
+
+class TestDeepGraphSplitter:
+    def test_simple_no_split(self):
+        splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"]})
+        flat = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
+        }
+        result = splitter.split(flat)
+        assert result == _PGSplitResult(
+            primary_node_ids={"lc1", "ndvi1"},
+            primary_backend_id="b1",
+            secondary_graphs=[],
+        )
+
+    def test_simple_split(self):
+        """
+        Most simple split use case: two load_collections from different backends, merged.
+ """ + splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) + flat = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == dirty_equals.IsOneOf( + _PGSplitResult( + primary_node_ids={"lc1", "lc2", "merge"}, + primary_backend_id="b1", + secondary_graphs=[ + _SubGraphData( + split_node="lc2", + node_ids={"lc2"}, + backend_id="b2", + ) + ], + ), + _PGSplitResult( + primary_node_ids={"lc1", "lc2", "merge"}, + primary_backend_id="b2", + secondary_graphs=[ + _SubGraphData( + split_node="lc1", + node_ids={"lc1"}, + backend_id="b1", + ) + ], + ), + ) + + def test_simple_deep_split(self): + """ + Simple deep split use case: + two load_collections from different backends, with some additional filtering, merged. + """ + splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) + flat = { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "temporal2": { + "process_id": "filter_temporal", + "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]}, + }, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == dirty_equals.IsOneOf( + _PGSplitResult( + primary_node_ids={"bands1", "lc1", "temporal2", "merge"}, + primary_backend_id="b1", + secondary_graphs=[ + _SubGraphData(split_node="temporal2", node_ids={"lc2", "temporal2"}, backend_id="b2") + ], + ), + _PGSplitResult( + primary_node_ids={"lc2", "temporal2", "bands1", "merge"}, + primary_backend_id="b2", + secondary_graphs=[_SubGraphData(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")], + ), + )