From 6b32c35455690f61b6a92f104ba3faca6e0ec534 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Sat, 31 Aug 2024 14:05:00 +0100 Subject: [PATCH] add method for input shape --- src/anemoi/inference/checkpoint/metadata/__init__.py | 9 +++++++++ .../inference/checkpoint/metadata/version_0_1_0.py | 10 ++++++++++ .../inference/checkpoint/metadata/version_0_2_0.py | 9 +++++++++ 3 files changed, 28 insertions(+) diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index 887b836..c0486f0 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -334,3 +334,12 @@ def report_loading_error(self): LOG.error("Training provenance:\n%s", json.dumps(provenance_training, indent=2)) ########################################################################### + + @property + def predict_step_shape(self): + return ( + 1, # Batch size + self.multi_step, # Lagged time steps + self.number_of_grid_points, # Grid points + self.num_input_features, # Fields + ) diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py index d3794f2..3fc5477 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py @@ -100,3 +100,13 @@ def graph(self, graph): dataset["attrs"] = dataset.copy() return ZarrRequest(dataset).graph(graph) + + @property + def number_of_grid_points(self): + from .version_0_2_0 import ZarrRequest + + dataset = self._dataset.copy() + if "attrs" not in dataset: + dataset["attrs"] = dataset.copy() + + return ZarrRequest(dataset).number_of_grid_points diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index b228822..cdeda1a 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -136,6 +136,15 @@ def dump(self, indent=0): def graph_kids(self): return [] + @property + def number_of_grid_points(self): + if "shape" in self.attributes: + return self.attributes["shape"][-1] + return { + "o96": 40_320, + "n320": 542_080, + }[self.attributes["resolution"].lower()] + class Forward(DataRequest): @cached_property