Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 3, 2025
1 parent 22d0d33 commit 2c488bb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/anemoi/inference/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def __init__(self, context, input, variables, variables_mask):
assert isinstance(input, DatasetInput), "Currently only boundary forcings from dataset supported."
self.input = input
if "output_mask" in context.checkpoint._supporting_arrays:
self.spatial_mask= ~context.checkpoint.load_supporting_array("output_mask")
self.spatial_mask = ~context.checkpoint.load_supporting_array("output_mask")
else:
self.spatial_mask=np.array([False] * len(input["latitudes"]) , dtype=bool)
self.spatial_mask = np.array([False] * len(input["latitudes"]), dtype=bool)
self.kinds = dict(retrieved=True) # Used for debugging

def __repr__(self):
Expand Down
11 changes: 5 additions & 6 deletions src/anemoi/inference/inputs/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DatasetInput(Input):
def __init__(self, context, args, kwargs):
super().__init__(context)

grid_indices=kwargs.pop("grid_indices", None)
grid_indices = kwargs.pop("grid_indices", None)

self.args, self.kwargs = args, kwargs
if context.verbosity > 0:
Expand All @@ -47,7 +47,6 @@ def __init__(self, context, args, kwargs):

self.grid_indices = slice(None) if grid_indices is None else grid_indices


@cached_property
def ds(self):
from anemoi.datasets import open_dataset
Expand All @@ -62,8 +61,8 @@ def create_input_state(self, *, date=None):
raise ValueError("`date` must be provided")

date = to_datetime(date)
latitudes=self.ds.latitudes
longitudes=self.ds.longitudes
latitudes = self.ds.latitudes
longitudes = self.ds.longitudes

input_state = dict(
date=date,
Expand All @@ -86,7 +85,7 @@ def create_input_state(self, *, date=None):
continue
# Squeeze the data to remove the ensemble dimension
values = np.squeeze(data[:, i], axis=1)
fields[variable] = values[:,self.grid_indices]
fields[variable] = values[:, self.grid_indices]

return input_state

Expand All @@ -100,7 +99,7 @@ def load_forcings(self, *, variables, dates):
# Reorder the dimensions to (variable, date, values)
data = np.swapaxes(data, 0, 1)
# apply reduction to `grid_indices`
data=data[...,self.grid_indices]
data = data[..., self.grid_indices]
return data

def _load_dates(self, dates):
Expand Down

0 comments on commit 2c488bb

Please sign in to comment.