Skip to content

Commit

Permalink
implement grid_indices mask
Browse files Browse the repository at this point in the history
  • Loading branch information
dietervdb-meteo committed Jan 3, 2025
1 parent e2de37d commit 66ed6f1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/anemoi/inference/inputs/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,26 @@ class DatasetInput(Input):

def __init__(self, context, args, kwargs):
super().__init__(context)

grid_indices=kwargs.pop("grid_indices", None)

self.args, self.kwargs = args, kwargs
if context.verbosity > 0:
LOG.info(
"Opening dataset with\nargs=%s\nkwargs=%s", json.dumps(args, indent=4), json.dumps(kwargs, indent=4)
)

if grid_indices is None and "grid_indices" in context.checkpoint._supporting_arrays:
grid_indices = context.checkpoint.load_supporting_array("grid_indices")
if context.verbosity > 0:
LOG.info(
"Loading supporting array `grid_indices` from checkpoint, \
the input grid will be reduced accordingly."
)

self.grid_indices = slice(None) if grid_indices is None else grid_indices


@cached_property
def ds(self):
from anemoi.datasets import open_dataset
Expand All @@ -48,11 +62,13 @@ def create_input_state(self, *, date=None):
raise ValueError("`date` must be provided")

date = to_datetime(date)
latitudes=self.ds.latitudes
longitudes=self.ds.longitudes

input_state = dict(
date=date,
latitudes=self.ds.latitudes,
longitudes=self.ds.longitudes,
latitudes=latitudes[self.grid_indices],
longitudes=longitudes[self.grid_indices],
fields=dict(),
)

Expand All @@ -69,7 +85,8 @@ def create_input_state(self, *, date=None):
if variable not in requested_variables:
continue
# Squeeze the data to remove the ensemble dimension
fields[variable] = np.squeeze(data[:, i], axis=1)
values = np.squeeze(data[:, i], axis=1)
fields[variable] = values[:,self.grid_indices]

return input_state

Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def output_tensor_index_to_variable(self):
@cached_property
def number_of_grid_points(self):
"""Return the number of grid points per fields"""
if "grid_indices" in self._supporting_arrays:
return len(self.load_supporting_array("grid_indices"))
try:
return self._metadata.dataset.shape[-1]
except AttributeError:
Expand Down

0 comments on commit 66ed6f1

Please sign in to comment.