From 61fd94f2b840ec52689e63a3edd7cbff90c26545 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 10 Oct 2024 16:13:33 +0100 Subject: [PATCH 1/8] use new metadata --- .../checkpoint/metadata/version_0_2_0.py | 75 +++++++++++-------- src/anemoi/inference/commands/request.py | 32 ++++++++ 2 files changed, 76 insertions(+), 31 deletions(-) create mode 100644 src/anemoi/inference/commands/request.py diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index 7601238..79ee6c8 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -7,6 +7,7 @@ import logging +from collections import defaultdict from functools import cached_property from . import Metadata @@ -154,6 +155,35 @@ def number_of_grid_points(self): "n320": 542_080, }[self.attributes["resolution"].lower()] + def retrieve_request(self): + from earthkit.data.utils.availability import Availability + + keys = ("type", "stream", "levtype") + pop = ( + "date", + "time", + "class", + "expver", + ) + requests = defaultdict(list) + for variable, metadata in self.attributes["variables_metadata"].items(): + metadata = metadata.copy() + key = tuple(metadata.get(k) for k in keys) + for k in pop: + metadata.pop(k, None) + + requests[key].append(metadata) + + for reqs in requests.values(): + + compressed = Availability(reqs) + for r in compressed.iterate(): + for k, v in r.items(): + if isinstance(v, (list, tuple)) and len(v) == 1: + r[k] = v[0] + if r: + yield r + class Forward(DataRequest): @cached_property @@ -171,15 +201,6 @@ def graph_kids(self): return [self.forward] -class SubsetRequest(Forward): - # Subset in time - pass - - -class StatisticsRequest(Forward): - pass - - class RenameRequest(Forward): # Drop variables @@ -259,16 +280,6 @@ def variables_with_nans(self): return sorted(result) -class ConcatRequest(MultiRequest): - # Concat in time - - pass - - -class EnsembleRequest(MultiRequest): - pass - - class MultiGridRequest(MultiRequest): @property def grid(self): @@ -280,7 +291,6 @@ def grid(self): def area(self): areas = [dataset.area for dataset in self.datasets] return areas[0] - raise NotImplementedError(";".join(str(g) for g in areas)) def mars_request(self): for d in self.datasets: @@ -302,15 +312,7 @@ def grid(self): return f"thinning({self.forward.grid})" -class InterpolatefrequencyRequest(Forward): - pass - - -class RescaleRequest(Forward): - pass - - -class ZarrwithmissingdatesRequest(ZarrRequest): +class ZarrWithMissingDatesRequest(ZarrRequest): pass @@ -354,9 +356,20 @@ def variables_with_nans(self): def data_request(specific): action = specific.pop("action") - action = action[0].upper() + action[1:].lower() + "Request" + action = action.capitalize() + "Request" LOG.debug(f"DataRequest: {action}") - return globals()[action](specific) + + klass = globals().get(action) + + if klass is None: + if "datasets" in specific: + klass = MultiRequest + elif "forward" in specific: + klass = Forward + else: + raise ValueError(f"Unknown action: {action}") + + return klass(specific) class Version_0_2_0(Metadata, Forward): diff --git a/src/anemoi/inference/commands/request.py b/src/anemoi/inference/commands/request.py new file mode 100644 index 0000000..8959462 --- /dev/null +++ b/src/anemoi/inference/commands/request.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +from ..checkpoint import Checkpoint +from . import Command + + +class RequestCmd(Command): + """Inspect the contents of a checkpoint file.""" + + need_logging = False + + def add_arguments(self, command_parser): + command_parser.add_argument("path", help="Path to the checkpoint.") + + def run(self, args): + + c = Checkpoint(args.path) + + for r in c.retrieve_request(): + print(r) + + +command = RequestCmd From a3f9052e47a982c325b4cb66e147765d2bf92e69 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 15 Oct 2024 18:39:53 +0100 Subject: [PATCH 2/8] generate mars request --- src/anemoi/inference/commands/request.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/anemoi/inference/commands/request.py b/src/anemoi/inference/commands/request.py index 8959462..0e01800 100644 --- a/src/anemoi/inference/commands/request.py +++ b/src/anemoi/inference/commands/request.py @@ -9,6 +9,8 @@ # +from anemoi.utils.grib import shortname_to_paramid + from ..checkpoint import Checkpoint from . import Command @@ -19,6 +21,9 @@ class RequestCmd(Command): need_logging = False def add_arguments(self, command_parser): + command_parser.description = self.__doc__ + command_parser.add_argument("--mars", action="store_true", help="Print the MARS request.") + command_parser.add_argument("--use-paramid", action="store_true", help="Use paramId instead of param.") command_parser.add_argument("path", help="Path to the checkpoint.") def run(self, args): @@ -26,6 +31,19 @@ def run(self, args): c = Checkpoint(args.path) for r in c.retrieve_request(): + if args.mars: + req = ["retrieve,target=data"] + for k, v in r.items(): + + if args.use_paramid and k == "param": + if not isinstance(v, (list, tuple)): + v = [v] + v = [shortname_to_paramid(x) for x in v] + + if isinstance(v, (list, tuple)): + v = "/".join([str(x) for x in v]) + req.append(f"{k}={v}") + r = ",".join(req) print(r) From 7839c646a61582e71cb4a154f0f5446476635a00 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 15 Oct 2024 20:40:06 +0000 Subject: [PATCH 3/8] add run command --- src/anemoi/inference/checkpoint/__init__.py | 26 +++++++++ .../checkpoint/metadata/version_0_2_0.py | 10 ++-- src/anemoi/inference/commands/request.py | 2 +- src/anemoi/inference/commands/run.py | 54 +++++++++++++++++++ 4 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 src/anemoi/inference/commands/run.py diff --git a/src/anemoi/inference/checkpoint/__init__.py b/src/anemoi/inference/checkpoint/__init__.py index f1e54cf..be4c4a8 100644 --- a/src/anemoi/inference/checkpoint/__init__.py +++ b/src/anemoi/inference/checkpoint/__init__.py @@ -7,6 +7,7 @@ from __future__ import annotations +import datetime import json import logging import os @@ -208,3 +209,28 @@ def validate_environment( LOG.info(f"Environment validation passed") return True + + def mars_requests(self, dates, use_paramid=False, **kwargs): + if not isinstance(dates, (list, tuple)): + dates = [dates] + + result = [] + + for r in self.retrieve_request(use_paramid=use_paramid): + for date in dates: + + r = r.copy() + + base = date + step = str(r.get("step", 0)).split("-")[-1] + step = int(step) + base = base - datetime.timedelta(hours=step) + + r["date"] = base.strftime("%Y-%m-%d") + r["time"] = base.strftime("%H%M") + + r.update(kwargs) + + result.append(r) + + return result diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index 79ee6c8..7616fbe 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -155,15 +155,14 @@ def number_of_grid_points(self): "n320": 542_080, }[self.attributes["resolution"].lower()] - def retrieve_request(self): + def retrieve_request(self, use_paramid=False): + from anemoi.utils.grib import shortname_to_paramid from earthkit.data.utils.availability import Availability - keys = ("type", "stream", "levtype") + keys = ("class", "expver", "type", "stream", "levtype") pop = ( "date", "time", - "class", - "expver", ) requests = defaultdict(list) for variable, metadata in self.attributes["variables_metadata"].items(): @@ -172,6 +171,9 @@ def retrieve_request(self): for k in pop: metadata.pop(k, None) + if use_paramid and "param" in metadata: + metadata["param"] = shortname_to_paramid(metadata["param"]) + requests[key].append(metadata) for reqs in requests.values(): diff --git a/src/anemoi/inference/commands/request.py b/src/anemoi/inference/commands/request.py index 0e01800..5585e5d 100644 --- a/src/anemoi/inference/commands/request.py +++ b/src/anemoi/inference/commands/request.py @@ -30,7 +30,7 @@ def run(self, args): c = Checkpoint(args.path) - for r in c.retrieve_request(): + for r in c.mars_requests(use_paramid=args.use_paramid): if args.mars: req = ["retrieve,target=data"] for k, v in r.items(): diff --git a/src/anemoi/inference/commands/run.py b/src/anemoi/inference/commands/run.py new file mode 100644 index 0000000..2028239 --- /dev/null +++ b/src/anemoi/inference/commands/run.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +import datetime + +from anemoi.utils.dates import as_datetime + +from ..runner import DefaultRunner +from . import Command + + +class RunCmd(Command): + """Inspect the contents of a checkpoint file.""" + + need_logging = False + + def add_arguments(self, command_parser): + command_parser.description = self.__doc__ + command_parser.add_argument("--use-paramid", action="store_true", help="Use paramId instead of param.") + command_parser.add_argument("--date", help="Date to use for the request.") + command_parser.add_argument("path", help="Path to the checkpoint.") + + def run(self, args): + import earthkit.data as ekd + + runner = DefaultRunner(args.path) + + date = as_datetime(args.date) + dates = [date + datetime.timedelta(hours=h) for h in runner.lagged] + + requests = runner.checkpoint.mars_requests(dates=dates, expver="0001", use_paramid=args.use_paramid) + + input_fields = ekd.from_source("empty") + for r in requests: + if r["class"] == "rd": + r["class"] = "od" + + r["grid"] = runner.checkpoint.grid + r["area"] = runner.checkpoint.area + + input_fields += ekd.from_source("mars", r) + + runner.run(input_fields=input_fields, lead_time=240, device="cuda") + + +command = RunCmd From 9700121a90f9009ae0ccf3a0ac6f4149dab527a3 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 15 Oct 2024 20:41:32 +0000 Subject: [PATCH 4/8] add run command --- src/anemoi/inference/commands/run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/anemoi/inference/commands/run.py b/src/anemoi/inference/commands/run.py index 2028239..f5e7fd7 100644 --- a/src/anemoi/inference/commands/run.py +++ b/src/anemoi/inference/commands/run.py @@ -36,7 +36,11 @@ def run(self, args): date = as_datetime(args.date) dates = [date + datetime.timedelta(hours=h) for h in runner.lagged] - requests = runner.checkpoint.mars_requests(dates=dates, expver="0001", use_paramid=args.use_paramid) + requests = runner.checkpoint.mars_requests( + dates=dates, + expver="0001", + use_paramid=args.use_paramid, + ) input_fields = ekd.from_source("empty") for r in requests: From 7c0e8168a4cb68c2e3c62cc5f2873f7e4e829b19 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 15 Oct 2024 21:32:09 +0000 Subject: [PATCH 5/8] add debug info --- src/anemoi/inference/commands/run.py | 5 ++ src/anemoi/inference/runner.py | 76 ++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/src/anemoi/inference/commands/run.py b/src/anemoi/inference/commands/run.py index f5e7fd7..61189d7 100644 --- a/src/anemoi/inference/commands/run.py +++ b/src/anemoi/inference/commands/run.py @@ -10,12 +10,15 @@ import datetime +import logging from anemoi.utils.dates import as_datetime from ..runner import DefaultRunner from . import Command +LOGGER = logging.getLogger(__name__) + class RunCmd(Command): """Inspect the contents of a checkpoint file.""" @@ -52,6 +55,8 @@ def run(self, args): input_fields += ekd.from_source("mars", r) + LOGGER.info("Running the model with the following %s fields, for %s dates", len(input_fields), len(dates)) + runner.run(input_fields=input_fields, lead_time=240, device="cuda") diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 84652d9..b32aa0c 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -54,6 +54,33 @@ def ignore(*args, **kwargs): pass +def _fix_eccodes_bug_for_levtype_sfc_and_grib2(input_fields): + from earthkit.data.indexing.fieldlist import FieldArray + + class BugFix: + def __init__(self, field): + self.field = field + + def __getattr__(self, name): + return getattr(self.field, name) + + def metadata(self, *args, **kwargs): + if len(args) > 0 and args[0] == "levelist": + return None + return self.field.metadata(*args, **kwargs) + + fixed = [] + for field in input_fields: + if field.metadata("levtype") in ("sfc", "o2d") and field.metadata("edition") == 2: + if field.metadata("levelist", default=None) is not None: + # LOGGER.warning("Fixing eccodes bug for levtype=sfc and grib2 %s", field) + fixed.append(BugFix(field)) + else: + fixed.append(field) + + return FieldArray(fixed) + + class Runner: """_summary_""" @@ -113,7 +140,39 @@ def run( autocast = AUTOCAST[autocast] + params_before = set(f.metadata("param") for f in input_fields) + + input_fields = _fix_eccodes_bug_for_levtype_sfc_and_grib2(input_fields) + + before = {id(x): x for x in input_fields} + LOGGER.info("Selecting fields %s", self.checkpoint.select) input_fields = input_fields.sel(**self.checkpoint.select) + LOGGER.info("Selected fields: %s", len(input_fields)) + + params_after = set(f.metadata("param") for f in input_fields) + + after = {} + for x in input_fields: + if hasattr(x, "field"): + after[id(x.field)] = x + else: + after[id(x)] = x + + LOGGER.info("Input fields before/after %s %s", params_before - params_after, params_after - params_before) + + if len(before) != len(after): + LOGGER.error("Input fields before=%s after=%s", len(before), len(after)) + param_level = self.checkpoint.select["param_level"] + for i, f in before.items(): + if i not in after: + if f.metadata("param") in param_level: + LOGGER.error( + "Field %s %s grib%s missing", + f.metadata("param"), + f.metadata("levelist", default="?"), + f.metadata("edition", default="?"), + ) + input_fields = input_fields.order_by(**self.checkpoint.order_by) number_of_grid_points = len(input_fields[0].grid_points()[0]) @@ -149,6 +208,8 @@ def run( number_of_grid_points, ) # nlags, nparams, ngrid + LOGGER.info("Input fields shape: %s (dates, variables, grid)", input_fields_numpy.shape) + # Used to check if we cover the whole input, with no overlaps check = np.full(self.checkpoint.num_input_features, fill_value=False, dtype=np.bool_) @@ -238,6 +299,21 @@ def run( # Check that the computed constant mask and the constant from input mask are disjoint # assert np.amax(prognostic_input_mask) < np.amin(constant_from_input_mask) + if len(prognostic_input_mask) != len(prognostic_data_from_retrieved_fields_mask): + LOGGER.error("Mismatch in prognostic_input_mask and prognostic_data_from_retrieved_fields_mask") + LOGGER.error("prognostic_input_mask: %s", prognostic_input_mask) + LOGGER.error("prognostic_data_from_retrieved_fields_mask: %s", prognostic_data_from_retrieved_fields_mask) + raise ValueError("Mismatch in prognostic_input_mask and prognostic_data_from_retrieved_fields_mask") + + if len(prognostic_data_from_retrieved_fields_mask) != input_fields_numpy.shape[1]: + LOGGER.error("Mismatch in prognostic_data_from_retrieved_fields_mask and input_fields_numpy") + LOGGER.error( + "prognostic_data_from_retrieved_fields_mask: %s shape=(variables)", + prognostic_data_from_retrieved_fields_mask.shape, + ) + LOGGER.error("input_fields_numpy: %s shape=(dates, variables, grid)", input_fields_numpy.shape) + raise ValueError("Mismatch in prognostic_input_mask and input_fields_numpy") + input_tensor_numpy[:, prognostic_input_mask] = input_fields_numpy[:, prognostic_data_from_retrieved_fields_mask] input_tensor_numpy[:, constant_from_input_mask] = input_fields_numpy[ From 6e7f026bc69072feb5b37bb626b2e1bce5d892e7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 16 Oct 2024 12:44:32 +0000 Subject: [PATCH 6/8] update --- src/anemoi/inference/checkpoint/__init__.py | 4 +- .../checkpoint/metadata/version_0_2_0.py | 4 +- src/anemoi/inference/commands/request.py | 6 +- src/anemoi/inference/commands/run.py | 20 +++- src/anemoi/inference/runner.py | 107 +++++++++++------- 5 files changed, 92 insertions(+), 49 deletions(-) diff --git a/src/anemoi/inference/checkpoint/__init__.py b/src/anemoi/inference/checkpoint/__init__.py index be4c4a8..754778e 100644 --- a/src/anemoi/inference/checkpoint/__init__.py +++ b/src/anemoi/inference/checkpoint/__init__.py @@ -210,13 +210,13 @@ def validate_environment( LOG.info(f"Environment validation passed") return True - def mars_requests(self, dates, use_paramid=False, **kwargs): + def mars_requests(self, dates, use_grib_paramid=False, **kwargs): if not isinstance(dates, (list, tuple)): dates = [dates] result = [] - for r in self.retrieve_request(use_paramid=use_paramid): + for r in self.retrieve_request(use_grib_paramid=use_grib_paramid): for date in dates: r = r.copy() diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index 7616fbe..e1cb39c 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -155,7 +155,7 @@ def number_of_grid_points(self): "n320": 542_080, }[self.attributes["resolution"].lower()] - def retrieve_request(self, use_paramid=False): + def retrieve_request(self, use_grib_paramid=False): from anemoi.utils.grib import shortname_to_paramid from earthkit.data.utils.availability import Availability @@ -171,7 +171,7 @@ def retrieve_request(self, use_paramid=False): for k in pop: metadata.pop(k, None) - if use_paramid and "param" in metadata: + if use_grib_paramid and "param" in metadata: metadata["param"] = shortname_to_paramid(metadata["param"]) requests[key].append(metadata) diff --git a/src/anemoi/inference/commands/request.py b/src/anemoi/inference/commands/request.py index 5585e5d..5721a7d 100644 --- a/src/anemoi/inference/commands/request.py +++ b/src/anemoi/inference/commands/request.py @@ -23,19 +23,19 @@ class RequestCmd(Command): def add_arguments(self, command_parser): command_parser.description = self.__doc__ command_parser.add_argument("--mars", action="store_true", help="Print the MARS request.") - command_parser.add_argument("--use-paramid", action="store_true", help="Use paramId instead of param.") + command_parser.add_argument("--use-grib-paramid", action="store_true", help="Use paramId instead of param.") command_parser.add_argument("path", help="Path to the checkpoint.") def run(self, args): c = Checkpoint(args.path) - for r in c.mars_requests(use_paramid=args.use_paramid): + for r in c.mars_requests(use_grib_paramid=args.use_grib_paramid): if args.mars: req = ["retrieve,target=data"] for k, v in r.items(): - if args.use_paramid and k == "param": + if args.use_grib_paramid and k == "param": if not isinstance(v, (list, tuple)): v = [v] v = [shortname_to_paramid(x) for x in v] diff --git a/src/anemoi/inference/commands/run.py b/src/anemoi/inference/commands/run.py index 61189d7..4a6e435 100644 --- a/src/anemoi/inference/commands/run.py +++ b/src/anemoi/inference/commands/run.py @@ -27,7 +27,7 @@ class RunCmd(Command): def add_arguments(self, command_parser): command_parser.description = self.__doc__ - command_parser.add_argument("--use-paramid", action="store_true", help="Use paramId instead of param.") + command_parser.add_argument("--use-grib-paramid", action="store_true", help="Use paramId instead of param.") command_parser.add_argument("--date", help="Date to use for the request.") command_parser.add_argument("path", help="Path to the checkpoint.") @@ -39,10 +39,19 @@ def run(self, args): date = as_datetime(args.date) dates = [date + datetime.timedelta(hours=h) for h in runner.lagged] + print("------------------------------------") + for n in runner.checkpoint.mars_requests( + dates=dates[0], + expver="0001", + use_grib_paramid=False, + ): + print("MARS", n) + print("------------------------------------") + requests = runner.checkpoint.mars_requests( dates=dates, expver="0001", - use_paramid=args.use_paramid, + use_grib_paramid=args.use_grib_paramid, ) input_fields = ekd.from_source("empty") @@ -53,11 +62,16 @@ def run(self, args): r["grid"] = runner.checkpoint.grid r["area"] = runner.checkpoint.area + print("MARS", r) + input_fields += ekd.from_source("mars", r) LOGGER.info("Running the model with the following %s fields, for %s dates", len(input_fields), len(dates)) - runner.run(input_fields=input_fields, lead_time=240, device="cuda") + run = runner.make_runner(input_fields=input_fields, lead_time=240, device="cuda") + run.run() + + runner.run(input_fields=input_fields, lead_time=244, device="cuda") command = RunCmd diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index b32aa0c..82bfdc0 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -64,11 +64,22 @@ def __init__(self, field): def __getattr__(self, name): return getattr(self.field, name) - def metadata(self, *args, **kwargs): + def metadata(self, *args, remapping=None, patches=None, **kwargs): + if remapping is not None or patches is not None: + from earthkit.data.core.order import build_remapping + + remapping = build_remapping(remapping, patches) + return remapping(self.metadata)(*args, **kwargs) + if len(args) > 0 and args[0] == "levelist": - return None + if "default" in kwargs: + return kwargs["default"] + raise KeyError("levelist") return self.field.metadata(*args, **kwargs) + def __repr__(self) -> str: + return repr(self.field) + fixed = [] for field in input_fields: if field.metadata("levtype") in ("sfc", "o2d") and field.metadata("edition") == 2: @@ -140,39 +151,12 @@ def run( autocast = AUTOCAST[autocast] - params_before = set(f.metadata("param") for f in input_fields) - input_fields = _fix_eccodes_bug_for_levtype_sfc_and_grib2(input_fields) - before = {id(x): x for x in input_fields} LOGGER.info("Selecting fields %s", self.checkpoint.select) input_fields = input_fields.sel(**self.checkpoint.select) LOGGER.info("Selected fields: %s", len(input_fields)) - params_after = set(f.metadata("param") for f in input_fields) - - after = {} - for x in input_fields: - if hasattr(x, "field"): - after[id(x.field)] = x - else: - after[id(x)] = x - - LOGGER.info("Input fields before/after %s %s", params_before - params_after, params_after - params_before) - - if len(before) != len(after): - LOGGER.error("Input fields before=%s after=%s", len(before), len(after)) - param_level = self.checkpoint.select["param_level"] - for i, f in before.items(): - if i not in after: - if f.metadata("param") in param_level: - LOGGER.error( - "Field %s %s grib%s missing", - f.metadata("param"), - f.metadata("levelist", default="?"), - f.metadata("edition", default="?"), - ) - input_fields = input_fields.order_by(**self.checkpoint.order_by) number_of_grid_points = len(input_fields[0].grid_points()[0]) @@ -222,6 +206,7 @@ def run( assert not np.any(check[computed_constant_mask]), check check[computed_constant_mask] = True kinds[computed_constant_mask] = "C" + self._report("computed_constant_mask", computed_constant_mask) # E.g. lsm, orography constant_from_input_mask = self.checkpoint.constants_from_input_mask @@ -229,12 +214,14 @@ def run( check[constant_from_input_mask] = True kinds[constant_from_input_mask] = "K" inputs[constant_from_input_mask] = True + self._report("constant_from_input_mask", constant_from_input_mask) # e.g. isolation computed_forcing_mask = self.checkpoint.computed_forcings_mask assert not np.any(check[computed_forcing_mask]), check check[computed_forcing_mask] = True kinds[computed_forcing_mask] = "F" + self._report("computed_forcing_mask", computed_forcing_mask) # e.g 2t, 10u, 10v prognostic_input_mask = self.checkpoint.prognostic_input_mask @@ -242,6 +229,7 @@ def run( check[prognostic_input_mask] = True kinds[prognostic_input_mask] = "P" inputs[prognostic_input_mask] = True + self._report("prognostic_input_mask", prognostic_input_mask) # if not np.all(check): @@ -282,7 +270,10 @@ def run( ) prognostic_data_from_retrieved_fields_mask = np.array(prognostic_data_from_retrieved_fields_mask) + self._report("prognostic_data_from_retrieved_fields_mask", prognostic_data_from_retrieved_fields_mask) + constant_data_from_retrieved_fields_mask = np.array(constant_data_from_retrieved_fields_mask) + self._report("constant_data_from_retrieved_fields_mask", constant_data_from_retrieved_fields_mask) # Build the input tensor @@ -305,20 +296,41 @@ def run( LOGGER.error("prognostic_data_from_retrieved_fields_mask: %s", prognostic_data_from_retrieved_fields_mask) raise ValueError("Mismatch in prognostic_input_mask and prognostic_data_from_retrieved_fields_mask") - if len(prognostic_data_from_retrieved_fields_mask) != input_fields_numpy.shape[1]: - LOGGER.error("Mismatch in prognostic_data_from_retrieved_fields_mask and input_fields_numpy") - LOGGER.error( - "prognostic_data_from_retrieved_fields_mask: %s shape=(variables)", - prognostic_data_from_retrieved_fields_mask.shape, + if len(prognostic_data_from_retrieved_fields_mask) > input_fields_numpy.shape[1]: + self._report_mismatch( + "prognostic_data_from_retrieved_fields_mask", + prognostic_data_from_retrieved_fields_mask, + input_fields, + input_fields_numpy, ) - LOGGER.error("input_fields_numpy: %s shape=(dates, variables, grid)", input_fields_numpy.shape) - raise ValueError("Mismatch in prognostic_input_mask and input_fields_numpy") input_tensor_numpy[:, prognostic_input_mask] = input_fields_numpy[:, prognostic_data_from_retrieved_fields_mask] - input_tensor_numpy[:, constant_from_input_mask] = input_fields_numpy[ - :, constant_data_from_retrieved_fields_mask - ] + if len(constant_data_from_retrieved_fields_mask) > input_fields_numpy.shape[1]: + self._report_mismatch( + "constant_data_from_retrieved_fields_mask", + constant_data_from_retrieved_fields_mask, + input_fields, + input_fields_numpy, + ) + + if len(constant_from_input_mask) != len(constant_data_from_retrieved_fields_mask): + LOGGER.error("Mismatch in constant_from_input_mask and constant_data_from_retrieved_fields_mask") + LOGGER.error("constant_from_input_mask: %s", constant_from_input_mask) + LOGGER.error("constant_data_from_retrieved_fields_mask: %s", constant_data_from_retrieved_fields_mask) + raise ValueError("Mismatch in constant_from_input_mask and constant_data_from_retrieved_fields_mask") + + try: + input_tensor_numpy[:, constant_from_input_mask] = input_fields_numpy[ + :, constant_data_from_retrieved_fields_mask + ] + except IndexError: + self._report_mismatch( + "constant_data_from_retrieved_fields_mask", + constant_data_from_retrieved_fields_mask, + input_fields, + input_fields_numpy, + ) constants = forcing_and_constants( source=input_fields[:1], @@ -523,6 +535,23 @@ def param_level_ml(self): return self.checkpoint.param_level_ml + def _report_mismatch(self, name, mask, input_fields, input_fields_numpy): + LOGGER.error("Mismatch in %s and input_fields", name) + LOGGER.error("%s: %s shape=(variables)", name, mask.shape) + LOGGER.error("input_fields_numpy: %s shape=(dates, variables, grid)", input_fields_numpy.shape) + LOGGER.error("MASK : %s", [self.checkpoint.variables[_] for _ in mask]) + + remapping = self.checkpoint.select["remapping"] + names = list(remapping.keys()) + + LOGGER.error( + "INPUT: %s", [input_fields[i].metadata(*names, remapping=remapping) for i in range(len(input_fields) // 2)] + ) + raise ValueError(f"Mismatch in {name} and input_fields") + + def _report(self, name, mask): + LOGGER.info("%s: %s", name, [self.checkpoint.variables[_] for _ in mask]) + class DefaultRunner(Runner): """_summary_ From d415adf9399c40f0877f80ae8b82d20a37b516cb Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 20 Oct 2024 19:41:10 +0000 Subject: [PATCH 7/8] Fix pre-commit regex --- .pre-commit-config.yaml | 3 +-- CHANGELOG.md | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9cb0fcf..4de5932 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,13 +43,12 @@ repos: rev: v0.6.9 hooks: - id: ruff - # Next line if for documenation cod snippets - exclude: '.*/[^_].*_\.py$' args: - --line-length=120 - --fix - --exit-non-zero-on-fix - --preview + - --exclude=docs/**/*_.py - repo: https://github.com/sphinx-contrib/sphinx-lint rev: v1.0.0 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b7bea2..92123a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,14 +11,15 @@ Keep it human-readable, your future self will thank you! ## [Unreleased] ### Added -- Fix: Enable inference when no constant forcings are used +- Fix: Enable inference when no constant forcings are used - Add anemoi-transform link to documentation ### Changed - Add cos_solar_zenith_angle to list of known forcings - Add missing classes in checkpoint handling +- Fix pre-commit regex ### Removed From 53b4eba16c6d09bc04e2a357ee7f0b9504078918 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 20 Oct 2024 19:41:30 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/inference/checkpoint/__init__.py | 2 +- src/anemoi/inference/checkpoint/metadata/__init__.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anemoi/inference/checkpoint/__init__.py b/src/anemoi/inference/checkpoint/__init__.py index f1e54cf..c314885 100644 --- a/src/anemoi/inference/checkpoint/__init__.py +++ b/src/anemoi/inference/checkpoint/__init__.py @@ -206,5 +206,5 @@ def validate_environment( raise ValueError(f"Invalid value for `on_difference`: {on_difference}") return False - LOG.info(f"Environment validation passed") + LOG.info("Environment validation passed") return True diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index da4949c..33bd5bc 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -345,7 +345,6 @@ def rounded_area(self, area): return area def report_loading_error(self): - import json if "provenance_training" not in self._metadata: return