diff --git a/pyproject.toml b/pyproject.toml index 109e098..95bbff6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ classifiers = [ "Operating System :: OS Independent", ] -dependencies = ["anemoi-utils"] +dependencies = ["anemoi-utils", "semantic-version"] [project.optional-dependencies] diff --git a/src/anemoi/inference/checkpoint/__init__.py b/src/anemoi/inference/checkpoint/__init__.py index 2fbc37f..c88f558 100644 --- a/src/anemoi/inference/checkpoint/__init__.py +++ b/src/anemoi/inference/checkpoint/__init__.py @@ -40,7 +40,7 @@ def _checkpoint_metadata(self, name): @cached_property def operational_config(self): try: - result = self._checkpoint_metadata("operational-config.json") + result = load_metadata(self.path, "operational-config.json") LOG.info(f"Using operational configuration from checkpoint {self.path}") return result except ValueError: diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index 9035c7d..caaff64 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -51,6 +51,9 @@ class Metadata: def __init__(self, metadata): self._metadata = metadata + def to_dict(self): + return self._metadata + @classmethod def from_metadata(cls, metadata): if isinstance(metadata["dataset"], list): @@ -292,6 +295,10 @@ def prognostic_params(self): return [self.index_to_variable[i] for i in self._indices["data"]["input"]["prognostic"]] ########################################################################### + @cached_property + def precision(self): + return self._config_training["precision"] + @cached_property def multi_step(self): return self._config_training["multistep_input"] diff --git a/src/anemoi/inference/commands/checkpoint.py b/src/anemoi/inference/commands/checkpoint.py index acfeb64..a9fb17e 100644 --- a/src/anemoi/inference/commands/checkpoint.py +++ b/src/anemoi/inference/commands/checkpoint.py @@ -9,6 +9,8 @@ # +import json + from ..checkpoint import Checkpoint from . import Command @@ -18,11 +20,15 @@ class CheckpointCmd(Command): need_logging = False def add_arguments(self, command_parser): + command_parser.add_argument("--json", action="store_true", help="Output in JSON format") command_parser.add_argument("path", help="Path to the checkpoint.") def run(self, args): c = Checkpoint(args.path) + if args.json: + print(json.dumps(c.to_dict(), indent=4, sort_keys=True)) + return print("area:", c.area) print("computed_constants:", c.computed_constants) @@ -39,6 +45,7 @@ def run(self, args): print("from_metadata:", c.from_metadata) print("grid:", c.grid) print("hour_steps:", c.hour_steps) + print("imputable variables", c.imputable_variables) print("imputable variables:", c.imputable_variables) print("imputable_variables:", c.imputable_variables) print("index_to_variable:", c.index_to_variable) @@ -50,6 +57,7 @@ def run(self, args): print("param_level_ml:", c.param_level_ml) print("param_level_pl:", c.param_level_pl) print("param_sfc:", c.param_sfc) + print("precision", c.precision) print("prognostic_data_input_mask:", c.prognostic_data_input_mask) print("prognostic_input_mask:", c.prognostic_input_mask) print("prognostic_output_mask:", c.prognostic_output_mask) @@ -59,10 +67,6 @@ def run(self, args): print("select:", c.select) print("variable_to_index:", c.variable_to_index) print("variables:", c.variables) - print() - result = list(range(0, c.multi_step)) - result = [-s * c.hour_steps for s in result] - print(sorted(result)) command = CheckpointCmd diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py new file mode 100644 index 0000000..294929f --- /dev/null +++ b/src/anemoi/inference/runner.py @@ -0,0 +1,381 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import logging +from functools import cached_property + +import numpy as np +import torch +from anemoi.utils.timer import Timer + +from .checkpoint import Checkpoint + +LOGGER = logging.getLogger(__name__) + + +AUTOCAST = { + "16": torch.float16, + "16-mixed": torch.float16, + "32": torch.float32, + "b16": torch.bfloat16, + "b16-mixed": torch.bfloat16, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + None: torch.float16, +} + + +def forcing_and_constants(source, date, param): + import climetlab as cml + + ds = cml.load_source( + "constants", + source, + date=date, + param=param, + ) + + assert len(ds) == len(param), (len(ds), len(param), date) + + return ds.to_numpy(dtype=np.float32) + + +def ignore(*args, **kwargs): + pass + + +class Runner: + + def __init__(self, checkpoint): + self.checkpoint = Checkpoint(checkpoint) + + def run( + self, + *, + input_fields, + lead_time, + device, + start_datetime=None, + output_callback=ignore, + autocast=None, + progress_callback=ignore, + ): + + if autocast is None: + autocast = self.checkpoint.precision + + autocast = AUTOCAST[autocast] + + input_fields = input_fields.sel(**self.checkpoint.select) + input_fields = input_fields.order_by(**self.checkpoint.order_by) + + number_of_grid_points = len(input_fields[0].grid_points()[0]) + + LOGGER.info("Loading input: %d fields (lagged=%d)", len(input_fields), len(self.lagged)) + + input_fields_numpy = input_fields.to_numpy(dtype=np.float32, reshape=False) + print(input_fields_numpy.shape) + + input_fields_numpy = input_fields_numpy.reshape( + len(self.lagged), + len(input_fields) // len(self.lagged), + number_of_grid_points, + ) # nlags, nparams, ngrid + + # Used to check if we cover the whole input, with no overlaps + check = np.full(self.checkpoint.num_input_features, fill_value=False, dtype=np.bool_) + + kinds = np.full(self.checkpoint.num_input_features, fill_value="?", dtype=np.character) + + inputs = np.full(self.checkpoint.num_input_features, fill_value=False, dtype=np.bool_) + + # E.g cos_latitude + computed_constant_mask = self.checkpoint.computed_constants_mask + assert not np.any(check[computed_constant_mask]), check + check[computed_constant_mask] = True + kinds[computed_constant_mask] = "C" + + # E.g. lsm, orography + constant_from_input_mask = self.checkpoint.constants_from_input_mask + assert not np.any(check[constant_from_input_mask]), check + check[constant_from_input_mask] = True + kinds[constant_from_input_mask] = "K" + inputs[constant_from_input_mask] = True + + # e.g. isolation + computed_forcing_mask = self.checkpoint.computed_forcings_mask + assert not np.any(check[computed_forcing_mask]), check + check[computed_forcing_mask] = True + kinds[computed_forcing_mask] = "F" + + # e.g 2t, 10u, 10v + prognostic_input_mask = self.checkpoint.prognostic_input_mask + assert not np.any(check[prognostic_input_mask]), check + check[prognostic_input_mask] = True + kinds[prognostic_input_mask] = "P" + inputs[prognostic_input_mask] = True + + # + if not np.all(check): + for i, c in enumerate(check): + if not c: + LOGGER.error( + "Missing %s %s %s", + i, + self.checkpoint.model_to_data[i], + self.checkpoint.index_to_variable[self.checkpoint.model_to_data[i]], + ) + raise RuntimeError("Missing fields") + + prognostic_data_from_retrieved_fields_mask = [] + constant_data_from_retrieved_fields_mask = [] + + MARS = {False: " ", True: "X"} + + retrieved_fields_index = 0 + for i, c in enumerate(check): + if inputs[i]: + assert kinds[i].decode() in ("P", "K") + if kinds[i].decode() == "P": + prognostic_data_from_retrieved_fields_mask.append(retrieved_fields_index) + else: + constant_data_from_retrieved_fields_mask.append(retrieved_fields_index) + retrieved_fields_index += 1 + + if hasattr(self, "verbose") and self.verbose: + print( + "{:4d} {:1s} {} {:4d} {:10s}".format( + i, + kinds[i].decode(), + MARS[inputs[i]], + self.checkpoint.model_to_data[i], + self.checkpoint.index_to_variable[self.checkpoint.model_to_data[i]], + ) + ) + + prognostic_data_from_retrieved_fields_mask = np.array(prognostic_data_from_retrieved_fields_mask) + constant_data_from_retrieved_fields_mask = np.array(constant_data_from_retrieved_fields_mask) + + # Build the input tensor + + input_tensor_numpy = np.full( + shape=( + len(self.lagged), + self.checkpoint.num_input_features, + number_of_grid_points, + ), + fill_value=np.nan, + dtype=np.float32, + ) # nlags, nparams, ngrid + + # Check that the computed constant mask and the constant from input mask are disjoint + # assert np.amax(prognostic_input_mask) < np.amin(constant_from_input_mask) + + input_tensor_numpy[:, prognostic_input_mask] = input_fields_numpy[:, prognostic_data_from_retrieved_fields_mask] + + input_tensor_numpy[:, constant_from_input_mask] = input_fields_numpy[ + :, constant_data_from_retrieved_fields_mask + ] + + if start_datetime is None: + start_datetime = input_fields.order_by(valid_datetime="ascending")[-1].datetime() + + constants = forcing_and_constants( + source=input_fields[:1], + param=self.checkpoint.computed_constants, + date=start_datetime, + ) + + for i in range(len(self.lagged)): + input_tensor_numpy[i, computed_constant_mask] = constants + + for i in range(len(self.lagged)): + forcings = forcing_and_constants( + source=input_fields[:1], + param=self.checkpoint.computed_forcings, + date=start_datetime + datetime.timedelta(hours=self.lagged[i]), + ) + input_tensor_numpy[i, computed_forcing_mask] = forcings + + LOGGER.info("Input tensor shape: %s", input_tensor_numpy.shape) + + imputable_variables = self.checkpoint.imputable_variables + can_be_missing = set() + # check for NaNs + for i in range(input_tensor_numpy.shape[1]): + name = self.checkpoint.index_to_variable[self.checkpoint.model_to_data[i]] + has_missing = np.isnan(input_tensor_numpy[:, i, :]).any() + is_imputable = name in imputable_variables + if has_missing: + can_be_missing.add(name) + if not is_imputable: + model_index = self.checkpoint.model_to_data[i] + LOGGER.error( + "No imputation specified for NaNs in %s (%s %s)", + name, + i, + model_index, + ) + raise ValueError(f"Field '{name}' has NaNs and is not marked as imputable") + + with Timer(f"Loading {self.checkpoint}"): + try: + model = torch.load( + self.checkpoint.path, + map_location=device, + ).to(device) + except Exception: + self.checkpoint.report_loading_error() + raise + + model.eval() + + torch.set_grad_enabled(False) + + input_tensor_torch = torch.from_numpy( + np.swapaxes( + input_tensor_numpy, + -2, + -1, + )[np.newaxis, ...] + ).to(device) + + prognostic_output_mask = self.checkpoint.prognostic_output_mask + diagnostic_output_mask = self.checkpoint.diagnostic_output_mask + + LOGGER.info("Using autocast %s", autocast) + + # Write dynamic fields + def get_most_recent_datetime(input_fields): + datetimes = [f.valid_datetime() for f in input_fields] + latest = datetimes[-1] + for d in datetimes: + assert d <= latest, (datetimes, d, latest) + return latest + + most_recent_datetime = get_most_recent_datetime(input_fields) + reference_fields = [f for f in input_fields if f.valid_datetime() == most_recent_datetime] + precip_template = reference_fields[self.checkpoint.variable_to_index["2t"]] + + accumulated_output = np.zeros( + shape=(len(diagnostic_output_mask), number_of_grid_points), + dtype=np.float32, + ) + + if self.checkpoint.diagnostic_params: + output_callback( + input_fields, + self.checkpoint.diagnostic_params, + precip_template, + accumulated_output[0].shape, + ) + else: + output_callback(input_fields) + + def add_ensemble_dim(func): + def wrapper(x): + x = x.unsqueeze(2) + y = func(x) + return y.squeeze_(2) + + return wrapper + + if True: # self.add_ensemble_dimension: + LOGGER.warning("🚨" * 80) + LOGGER.warning("Adding ensemble dimension.") + LOGGER.warning("If you are using that flags, your are using unsupported code") + LOGGER.warning("that can be removed any time in the near future") + LOGGER.warning("🚨" * 80) + + model.predict_step = add_ensemble_dim(model.predict_step) + + prognostic_params = self.checkpoint.prognostic_params + + # with self.stepper(self.hour_steps) as stepper: + + for i in progress_callback(range(lead_time // self.hour_steps)): + step = (i + 1) * self.hour_steps + + # Predict next state of atmosphere + with torch.autocast(device_type=device, dtype=autocast): + y_pred = model.predict_step(input_tensor_torch) + + # Detach tensor and squeeze + output = np.squeeze(y_pred.cpu().numpy()) + + prognostic_fields_numpy = output[:, prognostic_output_mask] + if len(diagnostic_output_mask): + diagnostic_fields_numpy = output[:, diagnostic_output_mask] + + for n, (m, param) in enumerate(zip(prognostic_data_from_retrieved_fields_mask, prognostic_params)): + template = reference_fields[m] + assert template.valid_datetime() == most_recent_datetime, ( + template.valid_datetime(), + most_recent_datetime, + ) + output_callback( + prognostic_fields_numpy[:, n], + template=template, + step=step, + check_nans=True, # param in can_be_missing, + ) + + # Write diagnostic fields + if len(diagnostic_output_mask): + for n, param in enumerate(self.checkpoint.diagnostic_params): + accumulated_output[n] += np.maximum(0, diagnostic_fields_numpy[:, n]) + assert precip_template.valid_datetime() == most_recent_datetime, ( + precip_template.valid_datetime(), + most_recent_datetime, + ) + output_callback( + accumulated_output[n], + stepType="accum", + template=precip_template, + startStep=0, + endStep=step, + param=param, + check_nans=True, # param in can_be_missing, + ) + + # Next step + + prognostic_fields = y_pred[..., prognostic_output_mask] + + # Compute new forcing + + forcing = forcing_and_constants( + source=input_fields[:1], + param=self.checkpoint.computed_forcings, + date=start_datetime + datetime.timedelta(hours=step), + ) + forcing = np.swapaxes(forcing[np.newaxis, np.newaxis, ...], -2, -1) + forcing = torch.from_numpy(forcing).to(device) + + # Update dynamic tensor for next iteration + input_tensor_torch = input_tensor_torch.roll(-1, dims=1) + input_tensor_torch[:, -1, :, prognostic_input_mask] = prognostic_fields + input_tensor_torch[:, -1, :, computed_forcing_mask] = forcing + + # progress_callback(i) + + @cached_property + def hour_steps(self): + return self.checkpoint.hour_steps + + @cached_property + def lagged(self): + result = list(range(0, self.checkpoint.multi_step)) + result = [-s * self.hour_steps for s in result] + return sorted(result) + + +class DefaultRunner(Runner): + pass