Skip to content

Commit

Permalink
Merge pull request #29 from ecmwf/feature/unstructured-grid
Browse files Browse the repository at this point in the history
Feature/unstructured grid
  • Loading branch information
b8raoult authored Oct 21, 2024
2 parents d1aa46b + 5119ed5 commit e84b4c1
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Keep it human-readable, your future self will thank you!

- Fix: Enable inference when no constant forcings are used
- Add anemoi-transform link to documentation
- Add support for unstructured grids

### Changed

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ classifiers = [

dynamic = [ "version" ]
dependencies = [
"anemoi-transform>=0.0.4",
"anemoi-utils>=0.3",
"aniso8601",
"anytree",
Expand Down
10 changes: 5 additions & 5 deletions src/anemoi/inference/checkpoint/metadata/version_0_2_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,23 @@ def __init__(self, metadata):

@property
def grid(self):
return self.request["grid"]
return self.request.get("grid")

@property
def area(self):
return self.request["area"]
return self.request.get("area")

@property
def param_sfc(self):
return self.request["param_level"].get("sfc", [])
return self.request.get("param_level", {}).get("sfc", [])

@property
def param_level_pl_pairs(self):
return self.request["param_level"].get("pl", [])
return self.request.get("param_level", {}).get("pl", [])

@property
def param_level_ml_pairs(self):
return self.request["param_level"].get("ml", [])
return self.request.get("param_level", {}).get("ml", [])

@property
def param_step_sfc_pairs(self):
Expand Down
13 changes: 11 additions & 2 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def run(
output_callback=ignore,
autocast=None,
progress_callback=ignore,
grid_field_list=None,
) -> None:
"""_summary_
Expand All @@ -130,6 +131,8 @@ def run(
_description_, by default None
progress_callback : _type_, optional
_description_, by default ignore
grid_field_list: _type_, optional
_description_, by default None
Raises
------
Expand Down Expand Up @@ -334,7 +337,7 @@ def run(


constants = forcing_and_constants(
source=input_fields[:1],
source=grid_field_list if grid_field_list is not None else input_fields[:1],
param=self.checkpoint.computed_constants,
date=start_datetime,
)
Expand Down Expand Up @@ -406,7 +409,13 @@ def get_most_recent_datetime(input_fields):

most_recent_datetime = get_most_recent_datetime(input_fields)
reference_fields = [f for f in input_fields if f.datetime()["valid_time"] == most_recent_datetime]
prognostic_template = reference_fields[self.checkpoint.variable_to_index["lsm"]]

if "lsm" in self.checkpoint.variable_to_index:
prognostic_template = reference_fields[self.checkpoint.variable_to_index["lsm"]]
else:
first = list(self.checkpoint.variable_to_index.keys())
LOGGER.warning("No LSM found to use as a GRIB template, using %s", first[0])
prognostic_template = reference_fields[0]

accumulated_output = np.zeros(
shape=(len(diagnostic_output_mask), number_of_grid_points),
Expand Down

0 comments on commit e84b4c1

Please sign in to comment.