Issue #150: batch of #155 review fixups
soxofaan committed Sep 20, 2024
1 parent 5a087dc commit 835b8be
Showing 3 changed files with 54 additions and 23 deletions.
15 changes: 9 additions & 6 deletions src/openeo_aggregator/backend.py
@@ -808,15 +808,16 @@ def create_job(
if "process_graph" not in process:
raise ProcessGraphMissingException()

# Coverage of messy "split_strategy" job option https://github.com/Open-EO/openeo-aggregator/issues/156
# TODO: better, more generic/specific job_option(s)?
# Coverage of messy "split_strategy" job option
# Also see https://github.com/Open-EO/openeo-aggregator/issues/156
# TODO: more generic and future-proof handling of split-strategy-related options?
split_strategy = (job_options or {}).get(JOB_OPTION_SPLIT_STRATEGY)
# TODO: this job option "tile_grid" is quite generic and not very explicit about being a job splitting approach
tile_grid = (job_options or {}).get(JOB_OPTION_TILE_GRID)

crossbackend_mode = (
split_strategy == "crossbackend" or isinstance(split_strategy, dict) and "crossbackend" in split_strategy
crossbackend_mode = split_strategy == "crossbackend" or (
isinstance(split_strategy, dict) and "crossbackend" in split_strategy
)
# TODO: the legacy job option "tile_grid" is quite generic and not very explicit
# about being a job splitting approach. Can we deprecate this in a way?
spatial_split_mode = tile_grid or split_strategy == "flimsy"

if crossbackend_mode:
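
A sketch of the job option shapes this detection accepts (the payloads below are illustrative assumptions, not documented API): the legacy string form and the newer dict form both enable cross-backend mode, while `"flimsy"` only triggers the spatial split path.

```python
from typing import Optional

JOB_OPTION_SPLIT_STRATEGY = "split_strategy"  # assumed to be the literal option name

def is_crossbackend_mode(job_options: Optional[dict]) -> bool:
    # Mirrors the detection logic above: accept the legacy string form
    # as well as a dict with a "crossbackend" entry.
    split_strategy = (job_options or {}).get(JOB_OPTION_SPLIT_STRATEGY)
    return split_strategy == "crossbackend" or (
        isinstance(split_strategy, dict) and "crossbackend" in split_strategy
    )

assert is_crossbackend_mode({"split_strategy": "crossbackend"})        # legacy string form
assert is_crossbackend_mode({"split_strategy": {"crossbackend": {}}})  # dict form
assert not is_crossbackend_mode({"split_strategy": "flimsy"})
assert not is_crossbackend_mode(None)
```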
@@ -951,6 +952,7 @@ def _create_crossbackend_job(

split_strategy = (job_options or {}).get(JOB_OPTION_SPLIT_STRATEGY)
if split_strategy == "crossbackend":
# Legacy job option format
graph_split_method = CROSSBACKEND_GRAPH_SPLIT_METHOD.SIMPLE
elif isinstance(split_strategy, dict) and isinstance(split_strategy.get("crossbackend"), dict):
graph_split_method = split_strategy.get("crossbackend", {}).get(
@@ -973,6 +975,7 @@ def backend_for_collection(collection_id) -> str:
elif graph_split_method == CROSSBACKEND_GRAPH_SPLIT_METHOD.DEEP:

def supporting_backends(node_id: str, node: dict) -> Union[List[str], None]:
# TODO: widen coverage by also checking process id availability
if node["process_id"] == "load_collection":
collection_id = node["arguments"]["id"]
return self._catalog.get_backends_for_collection(cid=collection_id)
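
A usage sketch of such a `supporting_backends` mapper, with a hypothetical collection-to-backend catalog standing in for `self._catalog`: only `load_collection` nodes constrain the backend choice; any other node returns `None` (unconstrained).

```python
from typing import List, Optional

# Hypothetical catalog mapping, for illustration only.
COLLECTION_BACKENDS = {"S2_L2A": ["b1", "b2"], "B1_NDVI": ["b3"]}

def supporting_backends(node_id: str, node: dict) -> Optional[List[str]]:
    # Only load_collection nodes pin down candidate backends.
    if node.get("process_id") == "load_collection":
        return COLLECTION_BACKENDS.get(node["arguments"]["id"])
    return None  # unconstrained: any backend may run this node

lc = {"process_id": "load_collection", "arguments": {"id": "S2_L2A"}}
ndvi = {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc"}}}
assert supporting_backends("lc", lc) == ["b1", "b2"]
assert supporting_backends("ndvi", ndvi) is None
```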
39 changes: 22 additions & 17 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -97,12 +97,16 @@ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId)


class _SubGraphData(NamedTuple):
"""Container for result of ProcessGraphSplitterInterface.split"""

split_node: NodeId
node_ids: Set[NodeId]
backend_id: BackendId


class _PGSplitResult(NamedTuple):
"""Container for result of ProcessGraphSplitterInterface.split"""

primary_node_ids: Set[NodeId]
primary_backend_id: BackendId
secondary_graphs: List[_SubGraphData]
@@ -187,10 +191,6 @@ class CrossBackendJobSplitter(AbstractJobSplitter):
"""

def __init__(self, graph_splitter: ProcessGraphSplitterInterface):
"""
:param backend_for_collection: callable that determines backend id for given collection id
:param always_split: split all load_collections, also when on same backend
"""
self._graph_splitter = graph_splitter

def split_streaming(
@@ -466,16 +466,13 @@ class _GVNode:
without having to worry about accidentally propagating changed state to other parts of the graph.
"""

# TODO: type coercion in __init__ of frozen dataclasses is a bit ugly. Use attrs with field converters instead?

# Node ids of other nodes this node depends on (aka parents)
depends_on: frozenset[NodeId]
# Node ids of other nodes that depend on this node (aka children)
flows_to: frozenset[NodeId]

# Backend ids this node is marked to be supported on
# value None means it is unknown/unconstrained for this node
# TODO: Move this to _GraphViewer as responsibility?
backend_candidates: Union[frozenset[BackendId], None]

def __init__(
@@ -485,14 +482,19 @@ def __init__(
flows_to: Union[Iterable[NodeId], NodeId, None] = None,
backend_candidates: Union[Iterable[BackendId], BackendId, None] = None,
):
# TODO: type coercion in __init__ of frozen dataclasses is a bit ugly. Use attrs with field converters instead?
super().__init__()
object.__setattr__(self, "depends_on", to_frozenset(depends_on or []))
object.__setattr__(self, "flows_to", to_frozenset(flows_to or []))
backend_candidates = to_frozenset(backend_candidates) if backend_candidates is not None else None
object.__setattr__(self, "backend_candidates", backend_candidates)

def __repr__(self):
return f"<{type(self).__name__}({self.depends_on}, {self.flows_to}, {self.backend_candidates})>"
# Somewhat cryptic, but compact representation of node attributes
depends_on = (" <" + ",".join(sorted(self.depends_on))) if self.depends_on else ""
flows_to = (" >" + ",".join(sorted(self.flows_to))) if self.flows_to else ""
backends = (" @" + ",".join(sorted(self.backend_candidates))) if self.backend_candidates else ""
return f"[{type(self).__name__}{depends_on}{flows_to}{backends}]"


class _GraphViewer:
Expand All @@ -505,7 +507,7 @@ class _GraphViewer:
def __init__(self, node_map: dict[NodeId, _GVNode]):
self._check_consistency(node_map=node_map)
# Work with a read-only proxy to prevent accidental changes
self._graph: Mapping[NodeId, _GVNode] = types.MappingProxyType(node_map)
self._graph: Mapping[NodeId, _GVNode] = types.MappingProxyType(node_map.copy())

@staticmethod
def _check_consistency(node_map: dict[NodeId, _GVNode]):
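
A minimal sketch of why the added `.copy()` matters: `types.MappingProxyType` is only a read-only view of the underlying dict, so without a private copy, later mutations of the caller's `node_map` would still leak through the proxy.

```python
import types

node_map = {"a": 1}
view_shared = types.MappingProxyType(node_map)
view_copied = types.MappingProxyType(node_map.copy())

node_map["b"] = 2  # caller keeps mutating their own dict
assert "b" in view_shared      # shared view sees the mutation
assert "b" not in view_copied  # copied snapshot stays isolated
```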
@@ -567,7 +569,6 @@ def from_edges(

graph = {
node_id: _GVNode(
# Note that we just use node id as process id. Do we have better options here?
depends_on=depends_on.get(node_id, []),
flows_to=flows_to.get(node_id, []),
backend_candidates=supporting_backends_mapper(node_id, {}),
@@ -593,9 +594,15 @@ def _walk(
auto_sort: bool = True,
) -> Iterator[NodeId]:
"""
Walk the graph nodes starting from given seed nodes, taking steps as defined by `next_nodes` function.
Optionally include seeds or not, and walk breadth first.
Walk the graph nodes starting from given seed nodes,
taking steps as defined by `next_nodes` function.
Walks breadth first and each node is only visited once.
:param include_seeds: whether to include the seed nodes in the walk
:param auto_sort: whether to visit the "next" nodes of a given node
    in lexicographically sorted order, to make the walk deterministic.
"""
# TODO: option to walk depth first instead of breadth first?
if auto_sort:
# Automatically sort next nodes to make walk more deterministic
prepare = sorted
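
A self-contained sketch of the walk's core shape, breadth-first with a visited set and an optionally sorted frontier for determinism (illustrative names, not the aggregator's actual helper):

```python
from collections import deque
from typing import Callable, Iterable, Iterator

def walk_bfs(
    seeds: Iterable[str],
    next_nodes: Callable[[str], Iterable[str]],
    auto_sort: bool = True,
) -> Iterator[str]:
    # Sorting the frontier at each step makes the visit order reproducible.
    prepare = sorted if auto_sort else list
    visited: set = set()
    queue = deque(prepare(seeds))
    while queue:
        node = queue.popleft()
        if node in visited:
            continue
        visited.add(node)  # each node is visited at most once
        yield node
        queue.extend(prepare(next_nodes(node)))

# Example: walk a tiny dependency graph downstream from "a".
graph = {"a": ["c", "b"], "b": ["d"], "c": ["d"], "d": []}
assert list(walk_bfs(["a"], graph.__getitem__)) == ["a", "b", "c", "d"]
```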
@@ -743,7 +750,7 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:

return up, down

def produce_split_locations(self, limit: int = 10) -> Iterator[List[NodeId]]:
def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]:
"""
Produce disjoint subgraphs that can be processed independently.
@@ -763,9 +770,8 @@ def produce_split_locations(self, limit: int = 10) -> Iterator[List[NodeId]]:

if forsaken_nodes:
# Sort forsaken nodes (based on forsaken parent count), to start higher up the graph
# TODO: avoid need for this sort, and just use a better scoring metric higher up?
forsaken_nodes = sorted(
forsaken_nodes, key=lambda n: sum(p in forsaken_nodes for p in self.node(n).depends_on)
forsaken_nodes, key=lambda n: (sum(p in forsaken_nodes for p in self.node(n).depends_on), n)
)
_log.debug(f"_GraphViewer.produce_split_locations: {forsaken_nodes=}")
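
The functional change here is the tie-breaker: extending the sort key from the forsaken-parent count to a `(count, node_id)` tuple makes the ordering deterministic when counts tie. A toy illustration:

```python
# "x" and "y" tie on the primary score; the tuple key breaks the tie by node id.
score = {"y": 1, "x": 1, "z": 0}
assert sorted(score, key=lambda n: (score[n], n)) == ["z", "x", "y"]
```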

@@ -784,8 +790,7 @@ def produce_split_locations(self, limit: int = 10) -> Iterator[List[NodeId]]:
_log.debug(f"_GraphViewer.produce_split_locations: {split_options=}")
if not split_options:
raise GraphSplitException("No split options found.")
# TODO: how to handle limit? will it scale feasibly to iterate over all possibilities at this point?
# TODO: smarter picking of split node (e.g. one with most upstream nodes)
# TODO: Do we really need a limit? Or is there a practical scalability risk to list all possibilities?
assert limit > 0
for split_node_id in split_options[:limit]:
_log.debug(f"_GraphViewer.produce_split_locations: splitting at {split_node_id=}")
23 changes: 23 additions & 0 deletions tests/partitionedjobs/test_crossbackend.py
@@ -478,6 +478,14 @@ def test_eq(self):
backend_candidates=["X"],
)

def test_repr(self):
assert repr(_GVNode()) == "[_GVNode]"
assert repr(_GVNode(depends_on="a")) == "[_GVNode <a]"
assert repr(_GVNode(depends_on=["b", "a"])) == "[_GVNode <a,b]"
assert repr(_GVNode(depends_on="a", flows_to="b")) == "[_GVNode <a >b]"
assert repr(_GVNode(depends_on=["a", "b"], flows_to=["foo", "bar"])) == "[_GVNode <a,b >bar,foo]"
assert repr(_GVNode(depends_on="a", flows_to="b", backend_candidates=["x", "yy"])) == "[_GVNode <a >b @x,yy]"


def supporting_backends_from_node_id_dict(data: dict) -> SupportingBackendsMapper:
return lambda node_id, node: data.get(node_id)
@@ -501,6 +509,21 @@ def test_check_consistency(self, node_map, expected_error):
with pytest.raises(GraphSplitException, match=expected_error):
_ = _GraphViewer(node_map=node_map)

def test_immutability(self):
node_map = {"a": _GVNode(flows_to="b"), "b": _GVNode(depends_on="a")}
graph = _GraphViewer(node_map=node_map)
assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))]

# Adding a node to the original map should not affect the graph
node_map["c"] = _GVNode()
assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))]

# Trying to mess with internals shouldn't work either
with pytest.raises(Exception, match="does not support item assignment"):
graph._graph["c"] = _GVNode()

assert sorted(graph.iter_nodes()) == [("a", _GVNode(flows_to="b")), ("b", _GVNode(depends_on="a"))]

def test_from_flat_graph_basic(self):
flat = {
"lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},