Skip to content

Commit

Permalink
Merge branch 'develop' into feature/rename_condition_to_state
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult authored Oct 21, 2024
2 parents e01abc2 + d1aa46b commit e7bdc6f
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 37 deletions.
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ repos:
rev: v0.6.9
hooks:
- id: ruff
# Next line is for documentation code snippets
exclude: '.*/[^_].*_\.py$'
args:
- --line-length=120
- --fix
- --exit-non-zero-on-fix
- --preview
- --exclude=docs/**/*_.py
- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v1.0.0
hooks:
Expand Down
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ Keep it human-readable, your future self will thank you!
## [Unreleased]

### Added
- Fix: Enable inference when no constant forcings are used

- Fix: Enable inference when no constant forcings are used
- Add anemoi-transform link to documentation

### Changed

- Add cos_solar_zenith_angle to list of known forcings
- Add missing classes in checkpoint handling
- Rename Condition to State [#24](https://github.com/ecmwf/anemoi-inference/pull/24)
- Fix pre-commit regex

### Removed

Expand Down
28 changes: 27 additions & 1 deletion src/anemoi/inference/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import datetime
import json
import logging
import os
Expand Down Expand Up @@ -206,5 +207,30 @@ def validate_environment(
raise ValueError(f"Invalid value for `on_difference`: {on_difference}")
return False

LOG.info(f"Environment validation passed")
LOG.info("Environment validation passed")
return True

def mars_requests(self, dates, use_grib_paramid=False, **kwargs):
    """Build one MARS request per (retrieve request, date) pair.

    Parameters
    ----------
    dates : datetime.datetime or list/tuple of datetime.datetime
        Date(s) for which requests are built. A single value is accepted
        and treated as a one-element list.
    use_grib_paramid : bool, optional
        Forwarded unchanged to ``self.retrieve_request``.
    **kwargs
        Extra keys merged into every request after ``date`` and ``time``
        have been set, so they override any key already present.

    Returns
    -------
    list of dict
        Requests ordered by retrieve request first, then by date.
    """
    if not isinstance(dates, (list, tuple)):
        dates = [dates]

    result = []

    for template in self.retrieve_request(use_grib_paramid=use_grib_paramid):
        for date in dates:
            # Always start from the pristine template: copying the request
            # built for the previous date would leak ``kwargs`` overrides
            # (e.g. ``step``) into the base-time computation below.
            request = template.copy()

            # A step may be a range such as "0-6"; its end is the offset
            # subtracted from the date to obtain the base date/time.
            step = int(str(template.get("step", 0)).split("-")[-1])
            base = date - datetime.timedelta(hours=step)

            request["date"] = base.strftime("%Y-%m-%d")
            request["time"] = base.strftime("%H%M")

            request.update(kwargs)

            result.append(request)

    return result
1 change: 0 additions & 1 deletion src/anemoi/inference/checkpoint/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,6 @@ def rounded_area(self, area):
return area

def report_loading_error(self):
import json

if "provenance_training" not in self._metadata:
return
Expand Down
77 changes: 46 additions & 31 deletions src/anemoi/inference/checkpoint/metadata/version_0_2_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


import logging
from collections import defaultdict
from functools import cached_property

from . import Metadata
Expand Down Expand Up @@ -154,6 +155,37 @@ def number_of_grid_points(self):
"n320": 542_080,
}[self.attributes["resolution"].lower()]

def retrieve_request(self, use_grib_paramid=False):
    """Yield compressed retrieve requests built from the variables metadata.

    The per-variable metadata found in ``self.attributes["variables_metadata"]``
    is grouped by its ("class", "expver", "type", "stream", "levtype") values,
    stripped of "date" and "time", and each group is compressed with
    earthkit's ``Availability``. Single-element lists/tuples in the resulting
    requests are replaced by their only element; empty requests are skipped.

    Parameters
    ----------
    use_grib_paramid : bool, optional
        When true, the "param" entry of each variable is passed through
        ``shortname_to_paramid`` before grouping.

    Yields
    ------
    dict
        One request per entry produced by ``Availability.iterate()``.
    """
    from anemoi.utils.grib import shortname_to_paramid
    from earthkit.data.utils.availability import Availability

    group_by = ("class", "expver", "type", "stream", "levtype")
    dropped = ("date", "time")

    grouped = defaultdict(list)
    for entry in self.attributes["variables_metadata"].values():
        fields = entry.copy()
        group = tuple(fields.get(name) for name in group_by)

        for name in dropped:
            fields.pop(name, None)

        if use_grib_paramid and "param" in fields:
            fields["param"] = shortname_to_paramid(fields["param"])

        grouped[group].append(fields)

    for members in grouped.values():
        for request in Availability(members).iterate():
            # Collect the one-element sequences first, then unwrap them in place.
            singletons = {
                name: value[0]
                for name, value in request.items()
                if isinstance(value, (list, tuple)) and len(value) == 1
            }
            request.update(singletons)

            if not request:
                continue
            yield request


class Forward(DataRequest):
@cached_property
Expand All @@ -171,15 +203,6 @@ def graph_kids(self):
return [self.forward]


class SubsetRequest(Forward):
# Subset in time
pass


class StatisticsRequest(Forward):
pass


class RenameRequest(Forward):

# Drop variables
Expand Down Expand Up @@ -259,16 +282,6 @@ def variables_with_nans(self):
return sorted(result)


class ConcatRequest(MultiRequest):
# Concat in time

pass


class EnsembleRequest(MultiRequest):
pass


class MultiGridRequest(MultiRequest):
@property
def grid(self):
Expand All @@ -280,7 +293,6 @@ def grid(self):
def area(self):
areas = [dataset.area for dataset in self.datasets]
return areas[0]
raise NotImplementedError(";".join(str(g) for g in areas))

def mars_request(self):
for d in self.datasets:
Expand All @@ -302,15 +314,7 @@ def grid(self):
return f"thinning({self.forward.grid})"


class InterpolatefrequencyRequest(Forward):
pass


class RescaleRequest(Forward):
pass


class ZarrwithmissingdatesRequest(ZarrRequest):
class ZarrWithMissingDatesRequest(ZarrRequest):
pass


Expand Down Expand Up @@ -354,9 +358,20 @@ def variables_with_nans(self):

def data_request(specific):
    """Instantiate the request class matching ``specific["action"]``.

    The class name is derived as ``action.capitalize() + "Request"`` and
    looked up among this module's globals. If no class matches, a generic
    class is chosen from the structure of ``specific``.

    Parameters
    ----------
    specific : dict
        Description of the request. Its "action" entry is removed (the dict
        is modified in place) and the remainder is passed to the class.

    Returns
    -------
    object
        An instance of the selected request class.

    Raises
    ------
    ValueError
        If no class matches and ``specific`` has neither a "datasets" nor a
        "forward" entry.
    """
    action = specific.pop("action")
    action = action.capitalize() + "Request"
    LOG.debug(f"DataRequest: {action}")

    namespace = globals()
    klass = namespace.get(action)

    if klass is None:
        # ``capitalize()`` lower-cases everything after the first letter, so a
        # CamelCase class such as ``ZarrWithMissingDatesRequest`` is missed by
        # the exact lookup; retry ignoring case before falling back.
        wanted = action.lower()
        klass = next(
            (obj for name, obj in list(namespace.items()) if isinstance(obj, type) and name.lower() == wanted),
            None,
        )

    if klass is None:
        # Unknown action: pick a generic class from the request structure.
        if "datasets" in specific:
            klass = MultiRequest
        elif "forward" in specific:
            klass = Forward
        else:
            raise ValueError(f"Unknown action: {action}")

    return klass(specific)


class Version_0_2_0(Metadata, Forward):
Expand Down
50 changes: 50 additions & 0 deletions src/anemoi/inference/commands/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


from anemoi.utils.grib import shortname_to_paramid

from ..checkpoint import Checkpoint
from . import Command


class RequestCmd(Command):
    """Print the requests needed to retrieve the input data of a checkpoint."""

    need_logging = False

    def add_arguments(self, command_parser):
        """Register the command-line arguments of the command."""
        command_parser.description = self.__doc__
        command_parser.add_argument("--mars", action="store_true", help="Print the MARS request.")
        command_parser.add_argument("--use-grib-paramid", action="store_true", help="Use paramId instead of param.")
        command_parser.add_argument(
            "--date",
            help="Date to use for the request, in ISO format (default: today at 00:00 UTC).",
        )
        command_parser.add_argument("path", help="Path to the checkpoint.")

    def run(self, args):
        """Print one request per line, as MARS text with ``--mars`` or as a dict otherwise."""
        import datetime

        c = Checkpoint(args.path)

        # ``mars_requests`` cannot be called without dates; use the one given
        # on the command line, or fall back to today at midnight (UTC, naive).
        if args.date is None:
            now = datetime.datetime.now(datetime.timezone.utc)
            date = datetime.datetime(now.year, now.month, now.day)
        else:
            date = datetime.datetime.fromisoformat(args.date)

        for r in c.mars_requests(dates=date, use_grib_paramid=args.use_grib_paramid):
            if not args.mars:
                print(r)
                continue

            req = ["retrieve,target=data"]
            for k, v in r.items():

                if args.use_grib_paramid and k == "param":
                    if not isinstance(v, (list, tuple)):
                        v = [v]
                    # NOTE(review): the requests were already asked to use
                    # paramIds; presumably they come back converted, so only
                    # values that still look like short names are converted.
                    v = [shortname_to_paramid(x) if isinstance(x, str) and not x.isdigit() else x for x in v]

                if isinstance(v, (list, tuple)):
                    v = "/".join([str(x) for x in v])
                req.append(f"{k}={v}")

            print(",".join(req))


command = RequestCmd
77 changes: 77 additions & 0 deletions src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


import datetime
import logging

from anemoi.utils.dates import as_datetime

from ..runner import DefaultRunner
from . import Command

LOGGER = logging.getLogger(__name__)


class RunCmd(Command):
    """Retrieve the input fields of a checkpoint from MARS and run the model."""

    need_logging = False

    def add_arguments(self, command_parser):
        """Register the command-line arguments of the command."""
        command_parser.description = self.__doc__
        command_parser.add_argument("--use-grib-paramid", action="store_true", help="Use paramId instead of param.")
        # The date is passed straight to ``as_datetime`` in ``run``, so let
        # argparse report a missing value as a usage error.
        command_parser.add_argument("--date", required=True, help="Date to use for the request.")
        command_parser.add_argument("path", help="Path to the checkpoint.")

    def run(self, args):
        """Fetch the input fields for the requested date and run the model."""
        import earthkit.data as ekd

        runner = DefaultRunner(args.path)

        date = as_datetime(args.date)
        # One date per entry of ``runner.lagged``, each an offset in hours.
        dates = [date + datetime.timedelta(hours=h) for h in runner.lagged]

        # Show the requests for the first date only, always with short names.
        print("------------------------------------")
        for n in runner.checkpoint.mars_requests(
            dates=dates[0],
            expver="0001",
            use_grib_paramid=False,
        ):
            print("MARS", n)
        print("------------------------------------")

        requests = runner.checkpoint.mars_requests(
            dates=dates,
            expver="0001",
            use_grib_paramid=args.use_grib_paramid,
        )

        input_fields = ekd.from_source("empty")
        for r in requests:
            # A request without a "class" key is left as it is.
            if r.get("class") == "rd":
                r["class"] = "od"

            r["grid"] = runner.checkpoint.grid
            r["area"] = runner.checkpoint.area

            print("MARS", r)

            input_fields += ekd.from_source("mars", r)

        LOGGER.info("Running the model with the following %s fields, for %s dates", len(input_fields), len(dates))

        # NOTE(review): the model is executed twice, with different lead times
        # (240 and 244); this looks like leftover experimentation - confirm
        # which call is intended before removing the other.
        run = runner.make_runner(input_fields=input_fields, lead_time=240, device="cuda")
        run.run()

        runner.run(input_fields=input_fields, lead_time=244, device="cuda")


command = RunCmd
Loading

0 comments on commit e7bdc6f

Please sign in to comment.