Skip to content

Commit

Permalink
Merge pull request #95 from ecmwf/feature/grid-indices
Browse files Browse the repository at this point in the history
Support models with unconnected nodes removed from input (LAM)
  • Loading branch information
dietervdb-meteo authored Jan 8, 2025
2 parents 1187d78 + 2abd3ac commit 2385b2d
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Keep it human-readable, your future self will thank you!
## [Unreleased]

### Added
- Add support for models with unconnected nodes dropped from input [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
- Change trigger for boundary forcings [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
- Add support for automatic loading of anemoi-datasets of more general type [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
- Add initial state output in netcdf format
- Fix: Enable inference when no constant forcings are used
- Add anemoi-transform link to documentation
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def mars_requests(self, *, variables, dates, use_grib_paramid=False, **kwargs):

@cached_property
def _supporting_arrays(self):
return self._metadata.supporting_arrays
return self._metadata._supporting_arrays

@property
def name(self):
Expand Down
6 changes: 4 additions & 2 deletions src/anemoi/inference/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ def __init__(self, context, input, variables, variables_mask):
self.variables_mask = variables_mask
assert isinstance(input, DatasetInput), "Currently only boundary forcings from dataset supported."
self.input = input
num_lam, num_other = input.ds.grids
self.spatial_mask = np.array([False] * num_lam + [True] * num_other, dtype=bool)
if "output_mask" in context.checkpoint._supporting_arrays:
self.spatial_mask = ~context.checkpoint.load_supporting_array("output_mask")
else:
self.spatial_mask = np.array([False] * len(input["latitudes"]), dtype=bool)
self.kinds = dict(retrieved=True) # Used for debugging

def __repr__(self):
Expand Down
24 changes: 21 additions & 3 deletions src/anemoi/inference/inputs/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,25 @@ class DatasetInput(Input):

def __init__(self, context, args, kwargs):
super().__init__(context)

grid_indices = kwargs.pop("grid_indices", None)

self.args, self.kwargs = args, kwargs
if context.verbosity > 0:
LOG.info(
"Opening dataset with\nargs=%s\nkwargs=%s", json.dumps(args, indent=4), json.dumps(kwargs, indent=4)
)

if grid_indices is None and "grid_indices" in context.checkpoint._supporting_arrays:
grid_indices = context.checkpoint.load_supporting_array("grid_indices")
if context.verbosity > 0:
LOG.info(
"Loading supporting array `grid_indices` from checkpoint, \
the input grid will be reduced accordingly."
)

self.grid_indices = slice(None) if grid_indices is None else grid_indices

@cached_property
def ds(self):
from anemoi.datasets import open_dataset
Expand All @@ -48,11 +61,13 @@ def create_input_state(self, *, date=None):
raise ValueError("`date` must be provided")

date = to_datetime(date)
latitudes = self.ds.latitudes
longitudes = self.ds.longitudes

input_state = dict(
date=date,
latitudes=self.ds.latitudes,
longitudes=self.ds.longitudes,
latitudes=latitudes[self.grid_indices],
longitudes=longitudes[self.grid_indices],
fields=dict(),
)

Expand All @@ -69,7 +84,8 @@ def create_input_state(self, *, date=None):
if variable not in requested_variables:
continue
# Squeeze the data to remove the ensemble dimension
fields[variable] = np.squeeze(data[:, i], axis=1)
values = np.squeeze(data[:, i], axis=1)
fields[variable] = values[:, self.grid_indices]

return input_state

Expand All @@ -82,6 +98,8 @@ def load_forcings(self, *, variables, dates):
data = np.squeeze(data, axis=2)
# Reorder the dimensions to (variable, date, values)
data = np.swapaxes(data, 0, 1)
# apply reduction to `grid_indices`
data = data[..., self.grid_indices]
return data

def _load_dates(self, dates):
Expand Down
9 changes: 4 additions & 5 deletions src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def output_tensor_index_to_variable(self):
@cached_property
def number_of_grid_points(self):
"""Return the number of grid points per fields"""
if "grid_indices" in self._supporting_arrays:
return len(self.load_supporting_array("grid_indices"))
try:
return self._metadata.dataset.shape[-1]
except AttributeError:
Expand Down Expand Up @@ -510,14 +512,13 @@ def _find(x):
_find(y)

if isinstance(x, dict):
if "dataset" in x:
if "dataset" in x and isinstance(x["dataset"], str):
result.append(x["dataset"])

for k, v in x.items():
_find(v)

_find(self._config.dataloader.training.dataset)

return result

def open_dataset_args_kwargs(self, *, use_original_paths, from_dataloader=None):
Expand Down Expand Up @@ -717,9 +718,7 @@ def boundary_forcings_inputs(self, context, input_state):

result = []

output_mask = self._config_model.get("output_mask", None)
if output_mask is not None:
assert output_mask == "cutout", "Currently only cutout as output mask supported."
if "output_mask" in self._supporting_arrays:
result.append(
context.create_boundary_forcings(
self.prognostic_variables,
Expand Down

0 comments on commit 2385b2d

Please sign in to comment.