diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index 78fb010..4e5d9b1 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -335,7 +335,7 @@ def report_loading_error(self): ########################################################################### - def digraph(self, label_maker=lambda x: {"label": x["label"]}): + def digraph(self, label_maker=lambda x: dict(label=x.kind)): import json digraph = ["digraph {"] diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index 1f4e184..e3993db 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -26,6 +26,12 @@ def __repr__(self) -> str: name = self.__class__.__name__ return name[:-7] if name.endswith("Request") else name + @property + def kind(self): + kind = self.__class__.__name__ + kind = kind[:-7] if kind.endswith("Request") else kind + return kind.lower() + def mars_request(self): def _as_list(v): @@ -84,7 +90,7 @@ def param_level_ml(self): return sorted(params), sorted(levels) def graph(self, digraph, nodes, node_maker): - nodes[f"N{id(self)}"] = node_maker(self.graph_node()) + nodes[f"N{id(self)}"] = node_maker(self) for kid in self.graph_kids(): digraph.append(f"N{id(self)} -> N{id(kid)}") kid.graph(digraph, nodes, node_maker) @@ -127,9 +133,6 @@ def variables_with_nans(self): def dump(self, indent=0): self.dump_content(indent) - def graph_node(self): - return {"label": "zarr", "metadata": {"uuid": self.attributes["uuid"]}} - def graph_kids(self): return [] @@ -146,23 +149,17 @@ def dump(self, indent=0): self.dump_content(indent) self.forward.dump(indent + 2) - def graph_node(self): - raise NotImplementedError(repr(self)) - def graph_kids(self): return [self.forward] class SubsetRequest(Forward): # Subset in time - - def graph_node(self): - return {"label": "subset"} + pass class StatisticsRequest(Forward): - def graph_node(self): - return {"label": "statistics"} + pass class RenameRequest(Forward): @@ -175,9 +172,6 @@ def variables_with_nans(self): rename = self.metadata["rename"] return sorted([rename.get(x, x) for x in self.forward.variables_with_nans]) - def graph_node(self): - return {"label": "rename"} - class MultiRequest(Forward): def __init__(self, metadata): @@ -244,21 +238,15 @@ def variables_with_nans(self): return sorted(result) - def graph_node(self): - return {"label": "join"} - class ConcatRequest(MultiRequest): # Concat in time - def graph_node(self): - return {"label": "concat"} + pass class EnsembleRequest(MultiRequest): - - def graph_node(self): - return {"label": "ensemble"} + pass class MultiGridRequest(MultiRequest): @@ -284,8 +272,7 @@ class GridsRequest(MultiGridRequest): class CutoutRequest(MultiGridRequest): - def graph_node(self): - return {"label": "cutout"} + pass class ThinningRequest(Forward): @@ -294,16 +281,10 @@ class ThinningRequest(Forward): def grid(self): return f"thinning({self.forward.grid})" - def graph_node(self): - return {"label": "thinning"} - class SelectRequest(Forward): # Select variables - def graph_node(self): - return {"label": "select", "metadata": {"variables": self.variables}} - @property def param_sfc(self): return [x for x in self.forward.param_sfc if x in self.variables] @@ -338,9 +319,6 @@ class DropRequest(SelectRequest): def variables_with_nans(self): return [x for x in self.forward.variables_with_nans if x in self.variables] - def graph_node(self): - return {"label": "drop"} - def data_request(specific): action = specific.pop("action") @@ -359,7 +337,4 @@ def __init__(self, metadata): def area(self): return self.rounded_area(self.forward.area) - def graph_node(self): - return {"label": "metadata", "metadata": {"version": "0.2.0"}} - #########################