From 2cc7c86f88a8af7c02f2db2dab2c37be1f6ee3d0 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Wed, 18 Sep 2024 10:24:39 +0200 Subject: [PATCH] Issue #150 improve DeepGraphSplitter test coverage also make graph walking more deterministic (e.g. to simplify test asserts) --- .../partitionedjobs/crossbackend.py | 26 ++- tests/partitionedjobs/test_crossbackend.py | 214 ++++++++++++------ 2 files changed, 168 insertions(+), 72 deletions(-) diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 9af87e2..984917f 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -453,6 +453,7 @@ class _FrozenNode: # TODO: instead of frozen dataclass: have __init__ with some type casting/validation. Or use attrs? # TODO: better name for this class? + # TODO: use NamedTuple instead of dataclass? # Node ids of other nodes this node depends on (aka parents) depends_on: frozenset[NodeId] @@ -560,18 +561,28 @@ def iter_nodes(self) -> Iterator[Tuple[NodeId, _FrozenNode]]: yield from self._graph.items() def _walk( - self, seeds: Iterable[NodeId], next_nodes: Callable[[NodeId], Iterable[NodeId]], include_seeds: bool = True + self, + seeds: Iterable[NodeId], + next_nodes: Callable[[NodeId], Iterable[NodeId]], + include_seeds: bool = True, + auto_sort: bool = True, ) -> Iterator[NodeId]: """ Walk the graph nodes starting from given seed nodes, taking steps as defined by `next_nodes` function. Optionally include seeds or not, and walk breadth first. """ + if auto_sort: + # Automatically sort next nodes to make walk more deterministic + prepare = sorted + else: + prepare = lambda x: x + if include_seeds: visited = set() - to_visit = list(seeds) + to_visit = list(prepare(seeds)) else: visited = set(seeds) - to_visit = [n for s in seeds for n in next_nodes(s)] + to_visit = [n for s in seeds for n in prepare(next_nodes(s))] while to_visit: node_id = to_visit.pop(0) @@ -579,7 +590,7 @@ def _walk( continue yield node_id visited.add(node_id) - to_visit.extend(set(next_nodes(node_id)).difference(visited)) + to_visit.extend(prepare(set(next_nodes(node_id)).difference(visited))) def walk_upstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]: """ @@ -728,7 +739,7 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]: # Sort forsaken nodes (based on forsaken parent count), to start higher up the graph # TODO: avoid need for this sort, and just use a better scoring metric higher up? forsaken_nodes = sorted( - forsaken_nodes, reverse=True, key=lambda n: sum(p in forsaken_nodes for p in self.node(n).depends_on) + forsaken_nodes, key=lambda n: sum(p in forsaken_nodes for p in self.node(n).depends_on) ) # Collect nodes where we could split the graph in disjoint subgraphs articulation_points: Set[NodeId] = set(self.find_articulation_points()) @@ -745,16 +756,17 @@ def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]: raise GraphSplitException("No split options found.") # TODO: how to handle limit? will it scale feasibly to iterate over all possibilities at this point? # TODO: smarter picking of split node (e.g. one with most upstream nodes) + assert limit > 0 for split_node_id in split_options[:limit]: # Split graph at this articulation point up, down = self.split_at(split_node_id) if down.find_forsaken_nodes(): - down_splits = list(down.produce_split_locations(limit=limit - 1)) + down_splits = list(down.produce_split_locations(limit=max(limit - 1, 1))) else: down_splits = [[]] if up.find_forsaken_nodes(): # TODO: will this actually happen? the upstream sub-graph should be single-backend by design? - up_splits = list(up.produce_split_locations(limit=limit - 1)) + up_splits = list(up.produce_split_locations(limit=max(limit - 1, 1))) else: up_splits = [[]] diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index f1e774f..66218e4 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional from unittest import mock -import dirty_equals import openeo import pytest import requests @@ -480,38 +479,38 @@ def test_from_edges(self): (["c"], False, ["a"]), (["a", "c"], True, ["a", "c"]), (["a", "c"], False, []), - (["c", "a"], True, ["c", "a"]), + (["c", "a"], True, ["a", "c"]), (["c", "a"], False, []), - ( - ["e"], - True, - dirty_equals.IsOneOf( - ["e", "c", "d", "a", "b"], - ["e", "d", "c", "b", "a"], - ), - ), - ( - ["e"], - False, - dirty_equals.IsOneOf( - ["c", "d", "a", "b"], - ["d", "c", "b", "a"], - ), - ), - (["e", "d"], True, ["e", "d", "c", "b", "a"]), + (["e"], True, ["e", "c", "d", "a", "b"]), + (["e"], False, ["c", "d", "a", "b"]), + (["e", "d"], True, ["d", "e", "b", "c", "a"]), (["e", "d"], False, ["c", "b", "a"]), (["d", "e"], True, ["d", "e", "b", "c", "a"]), (["d", "e"], False, ["b", "c", "a"]), - (["f", "c"], True, ["f", "c", "e", "a", "d", "b"]), + (["f", "c"], True, ["c", "f", "a", "e", "d", "b"]), (["f", "c"], False, ["e", "a", "d", "b"]), ], ) def test_walk_upstream_nodes(self, seed, include_seeds, expected): - graph = _FrozenGraph.from_edges([("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")]) + graph = _FrozenGraph.from_edges( + # a b + # | | + # c d + # \ / + # e + # | + # f + [("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")] + ) assert list(graph.walk_upstream_nodes(seed, include_seeds)) == expected def test_get_backend_candidates_basic(self): graph = _FrozenGraph.from_edges( + # a + # | + # b c + # \ / + # d [("a", "b"), ("b", "d"), ("c", "d")], backend_candidates_map={"a": ["b1"], "c": ["b2"]}, ) @@ -530,6 +529,11 @@ def test_get_backend_candidates_basic(self): def test_get_backend_candidates_none(self): graph = _FrozenGraph.from_edges( + # a + # | + # b c + # \ / + # d [("a", "b"), ("b", "d"), ("c", "d")], backend_candidates_map={}, ) @@ -543,6 +547,11 @@ def test_get_backend_candidates_none(self): def test_get_backend_candidates_intersection(self): graph = _FrozenGraph.from_edges( + # a b c + # \ / \ / + # d e + # \ / + # f [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f")], backend_candidates_map={"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]}, ) @@ -559,6 +568,13 @@ def test_get_backend_candidates_intersection(self): def test_find_forsaken_nodes(self): graph = _FrozenGraph.from_edges( + # a b c + # \ / \ / + # d e + # \ / + # f + # / \ + # g h [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f"), ("f", "g"), ("f", "h")], backend_candidates_map={"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]}, ) @@ -721,6 +737,9 @@ def test_split_at_non_articulation_point(self): def test_produce_split_locations_simple(self): """Simple produce_split_locations use case: no need for splits""" flat = { + # lc1 + # | + # ndvi1 "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True}, } @@ -733,6 +752,9 @@ def test_produce_split_locations_merge_basic(self): two load collections on different backends and a merge """ flat = { + # lc1 lc2 + # \ / + # merge1 "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, "merge1": { @@ -745,6 +767,11 @@ def test_produce_split_locations_merge_basic(self): def test_produce_split_locations_merge_longer(self): flat = { + # lc1 lc2 + # | | + # bands1 bands2 + # \ / + # merge1 "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, @@ -756,13 +783,17 @@ def test_produce_split_locations_merge_longer(self): } graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) assert sorted(graph.produce_split_locations(limit=2)) == [["bands1"], ["bands2"]] - assert list(graph.produce_split_locations(limit=4)) == dirty_equals.IsOneOf( - [["bands1"], ["bands2"], ["lc1"], ["lc2"]], - [["bands2"], ["bands1"], ["lc2"], ["lc1"]], - ) + assert list(graph.produce_split_locations(limit=4)) == [["bands1"], ["bands2"], ["lc1"], ["lc2"]] def test_produce_split_locations_merge_longer_triangle(self): flat = { + # lc1 + # / | + # bands1 | lc2 + # \ | | + # mask1 bands2 + # \ / + # merge1 "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, "mask1": { @@ -777,10 +808,7 @@ def test_produce_split_locations_merge_longer_triangle(self): }, } graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) - assert list(graph.produce_split_locations(limit=4)) == dirty_equals.IsOneOf( - [["mask1"], ["bands2"], ["lc1"], ["lc2"]], - [["bands2"], ["mask1"], ["lc2"], ["lc1"]], - ) + assert list(graph.produce_split_locations(limit=4)) == [["bands2"], ["mask1"], ["lc2"], ["lc1"]] class TestDeepGraphSplitter: @@ -803,6 +831,9 @@ def test_simple_split(self): """ splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) flat = { + # lc1 lc2 + # \ / + # merge "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, "merge": { @@ -812,29 +843,16 @@ def test_simple_split(self): }, } result = splitter.split(flat) - assert result == dirty_equals.IsOneOf( - _PGSplitResult( - primary_node_ids={"lc1", "lc2", "merge"}, - primary_backend_id="b1", - secondary_graphs=[ - _SubGraphData( - split_node="lc2", - node_ids={"lc2"}, - backend_id="b2", - ) - ], - ), - _PGSplitResult( - primary_node_ids={"lc1", "lc2", "merge"}, - primary_backend_id="b2", - secondary_graphs=[ - _SubGraphData( - split_node="lc1", - node_ids={"lc1"}, - backend_id="b1", - ) - ], - ), + assert result == _PGSplitResult( + primary_node_ids={"lc1", "lc2", "merge"}, + primary_backend_id="b2", + secondary_graphs=[ + _SubGraphData( + split_node="lc1", + node_ids={"lc1"}, + backend_id="b1", + ) + ], ) def test_simple_deep_split(self): @@ -844,6 +862,11 @@ def test_simple_deep_split(self): """ splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) flat = { + # lc1 lc2 + # | | + # bands1 temporal2 + # \ / + # merge "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, @@ -858,17 +881,78 @@ def test_simple_deep_split(self): }, } result = splitter.split(flat) - assert result == dirty_equals.IsOneOf( - _PGSplitResult( - primary_node_ids={"bands1", "lc1", "temporal2", "merge"}, - primary_backend_id="b1", - secondary_graphs=[ - _SubGraphData(split_node="temporal2", node_ids={"lc2", "temporal2"}, backend_id="b2") - ], - ), - _PGSplitResult( - primary_node_ids={"lc2", "temporal2", "bands1", "merge"}, - primary_backend_id="b2", - secondary_graphs=[_SubGraphData(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")], - ), + assert result == _PGSplitResult( + primary_node_ids={"lc2", "temporal2", "bands1", "merge"}, + primary_backend_id="b2", + secondary_graphs=[_SubGraphData(split_node="bands1", node_ids={"lc1", "bands1"}, backend_id="b1")], + ) + + def test_shallow_triple_split(self): + splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]}) + flat = { + # lc1 lc2 lc3 + # \ / / + # merge1 / + # \ / + # merge2 + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc3": {"process_id": "load_collection", "arguments": {"id": "S3"}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + }, + "merge2": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "merge1"}, "cube2": {"from_node": "lc3"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"lc1", "lc2", "lc3", "merge1", "merge2"}, + primary_backend_id="b2", + secondary_graphs=[ + _SubGraphData(split_node="lc1", node_ids={"lc1"}, backend_id="b1"), + _SubGraphData(split_node="lc3", node_ids={"lc3"}, backend_id="b3"), + ], + ) + + def test_triple_split(self): + splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]}) + flat = { + # lc1 lc2 lc3 + # | | | + # bands1 temporal2 spatial3 + # \ / / + # merge1 / + # \ / + # merge2 + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc3": {"process_id": "load_collection", "arguments": {"id": "S3"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "temporal2": { + "process_id": "filter_temporal", + "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]}, + }, + "spatial3": {"process_id": "filter_spatial", "arguments": {"data": {"from_node": "lc3"}, "extent": "EU"}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + }, + "merge2": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "merge1"}, "cube2": {"from_node": "spatial3"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids={"merge2", "merge1", "lc3", "spatial3"}, + primary_backend_id="b3", + secondary_graphs=[ + _SubGraphData(split_node="bands1", node_ids={"bands1", "lc1"}, backend_id="b1"), + _SubGraphData(split_node="merge1", node_ids={"bands1", "merge1", "temporal2", "lc2"}, backend_id="b2"), + ], )