Skip to content

Commit

Permalink
add method for input shape
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Aug 31, 2024
1 parent 6c7fb92 commit 6b32c35
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,12 @@ def report_loading_error(self):
LOG.error("Training provenance:\n%s", json.dumps(provenance_training, indent=2))

###########################################################################

@property
def predict_step_shape(self):
return (
1, # Batch size
self.multi_step, # Lagged time steps
self.number_of_grid_points, # Grid points
self.num_input_features, # Fields
)
10 changes: 10 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/version_0_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,13 @@ def graph(self, graph):
dataset["attrs"] = dataset.copy()

return ZarrRequest(dataset).graph(graph)

@property
def number_of_grid_points(self):
from .version_0_2_0 import ZarrRequest

dataset = self._dataset.copy()
if "attrs" not in dataset:
dataset["attrs"] = dataset.copy()

return ZarrRequest(dataset).number_of_grid_points
9 changes: 9 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/version_0_2_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ def dump(self, indent=0):
def graph_kids(self):
return []

@property
def number_of_grid_points(self):
if "shape" in self.attributes:
return self.attributes["shape"][-1]
return {
"o96": 40_320,
"n320": 542_080,
}[self.attributes["resolution"].lower()]


class Forward(DataRequest):
@cached_property
Expand Down

0 comments on commit 6b32c35

Please sign in to comment.