From 22d0d336041f7f4b611a09de0d589d3740bf9446 Mon Sep 17 00:00:00 2001 From: dietervdb-meteo Date: Fri, 3 Jan 2025 16:47:13 +0200 Subject: [PATCH] adapt forcings --- src/anemoi/inference/forcings.py | 6 ++++-- src/anemoi/inference/inputs/dataset.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/anemoi/inference/forcings.py b/src/anemoi/inference/forcings.py index 984d90a..4b7d042 100644 --- a/src/anemoi/inference/forcings.py +++ b/src/anemoi/inference/forcings.py @@ -136,8 +136,10 @@ def __init__(self, context, input, variables, variables_mask): self.variables_mask = variables_mask assert isinstance(input, DatasetInput), "Currently only boundary forcings from dataset supported." self.input = input - num_lam, num_other = input.ds.grids - self.spatial_mask = np.array([False] * num_lam + [True] * num_other, dtype=bool) + if "output_mask" in context.checkpoint._supporting_arrays: + self.spatial_mask= ~context.checkpoint.load_supporting_array("output_mask") + else: + self.spatial_mask=np.array([False] * len(input["latitudes"]) , dtype=bool) self.kinds = dict(retrieved=True) # Used for debugging def __repr__(self): diff --git a/src/anemoi/inference/inputs/dataset.py b/src/anemoi/inference/inputs/dataset.py index 2696fab..e0fcfee 100644 --- a/src/anemoi/inference/inputs/dataset.py +++ b/src/anemoi/inference/inputs/dataset.py @@ -99,6 +99,8 @@ def load_forcings(self, *, variables, dates): data = np.squeeze(data, axis=2) # Reorder the dimensions to (variable, date, values) data = np.swapaxes(data, 0, 1) + # apply reduction to `grid_indices` + data=data[...,self.grid_indices] return data def _load_dates(self, dates):