diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index 887b836..c0486f0 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -334,3 +334,12 @@ def report_loading_error(self): LOG.error("Training provenance:\n%s", json.dumps(provenance_training, indent=2)) ########################################################################### + + @property + def predict_step_shape(self): + return ( + 1, # Batch size + self.multi_step, # Lagged time steps + self.number_of_grid_points, # Grid points + self.num_input_features, # Fields + ) diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py index d3794f2..3fc5477 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py @@ -100,3 +100,13 @@ def graph(self, graph): dataset["attrs"] = dataset.copy() return ZarrRequest(dataset).graph(graph) + + @property + def number_of_grid_points(self): + from .version_0_2_0 import ZarrRequest + + dataset = self._dataset.copy() + if "attrs" not in dataset: + dataset["attrs"] = dataset.copy() + + return ZarrRequest(dataset).number_of_grid_points diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index b228822..cdeda1a 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -136,6 +136,15 @@ def dump(self, indent=0): def graph_kids(self): return [] + @property + def number_of_grid_points(self): + if "shape" in self.attributes: + return self.attributes["shape"][-1] + return { + "o96": 40_320, + "n320": 542_080, + }[self.attributes["resolution"].lower()] + class Forward(DataRequest): @cached_property