Skip to content

Commit

Permalink
add debug info
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 15, 2024
1 parent 9700121 commit 7c0e816
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@


import datetime
import logging

from anemoi.utils.dates import as_datetime

from ..runner import DefaultRunner
from . import Command

LOGGER = logging.getLogger(__name__)


class RunCmd(Command):
"""Inspect the contents of a checkpoint file."""
Expand Down Expand Up @@ -52,6 +55,8 @@ def run(self, args):

input_fields += ekd.from_source("mars", r)

LOGGER.info("Running the model with the following %s fields, for %s dates", len(input_fields), len(dates))

runner.run(input_fields=input_fields, lead_time=240, device="cuda")


Expand Down
76 changes: 76 additions & 0 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,33 @@ def ignore(*args, **kwargs):
pass


def _fix_eccodes_bug_for_levtype_sfc_and_grib2(input_fields):
from earthkit.data.indexing.fieldlist import FieldArray

class BugFix:
def __init__(self, field):
self.field = field

def __getattr__(self, name):
return getattr(self.field, name)

def metadata(self, *args, **kwargs):
if len(args) > 0 and args[0] == "levelist":
return None
return self.field.metadata(*args, **kwargs)

fixed = []
for field in input_fields:
if field.metadata("levtype") in ("sfc", "o2d") and field.metadata("edition") == 2:
if field.metadata("levelist", default=None) is not None:
# LOGGER.warning("Fixing eccodes bug for levtype=sfc and grib2 %s", field)
fixed.append(BugFix(field))
else:
fixed.append(field)

return FieldArray(fixed)


class Runner:
"""_summary_"""

Expand Down Expand Up @@ -113,7 +140,39 @@ def run(

autocast = AUTOCAST[autocast]

params_before = set(f.metadata("param") for f in input_fields)

input_fields = _fix_eccodes_bug_for_levtype_sfc_and_grib2(input_fields)

before = {id(x): x for x in input_fields}
LOGGER.info("Selecting fields %s", self.checkpoint.select)
input_fields = input_fields.sel(**self.checkpoint.select)
LOGGER.info("Selected fields: %s", len(input_fields))

params_after = set(f.metadata("param") for f in input_fields)

after = {}
for x in input_fields:
if hasattr(x, "field"):
after[id(x.field)] = x
else:
after[id(x)] = x

LOGGER.info("Input fields before/after %s %s", params_before - params_after, params_after - params_before)

if len(before) != len(after):
LOGGER.error("Input fields before=%s after=%s", len(before), len(after))
param_level = self.checkpoint.select["param_level"]
for i, f in before.items():
if i not in after:
if f.metadata("param") in param_level:
LOGGER.error(
"Field %s %s grib%s missing",
f.metadata("param"),
f.metadata("levelist", default="?"),
f.metadata("edition", default="?"),
)

input_fields = input_fields.order_by(**self.checkpoint.order_by)

number_of_grid_points = len(input_fields[0].grid_points()[0])
Expand Down Expand Up @@ -149,6 +208,8 @@ def run(
number_of_grid_points,
) # nlags, nparams, ngrid

LOGGER.info("Input fields shape: %s (dates, variables, grid)", input_fields_numpy.shape)

# Used to check if we cover the whole input, with no overlaps
check = np.full(self.checkpoint.num_input_features, fill_value=False, dtype=np.bool_)

Expand Down Expand Up @@ -238,6 +299,21 @@ def run(
# Check that the computed constant mask and the constant from input mask are disjoint
# assert np.amax(prognostic_input_mask) < np.amin(constant_from_input_mask)

if len(prognostic_input_mask) != len(prognostic_data_from_retrieved_fields_mask):
LOGGER.error("Mismatch in prognostic_input_mask and prognostic_data_from_retrieved_fields_mask")
LOGGER.error("prognostic_input_mask: %s", prognostic_input_mask)
LOGGER.error("prognostic_data_from_retrieved_fields_mask: %s", prognostic_data_from_retrieved_fields_mask)
raise ValueError("Mismatch in prognostic_input_mask and prognostic_data_from_retrieved_fields_mask")

if len(prognostic_data_from_retrieved_fields_mask) != input_fields_numpy.shape[1]:
LOGGER.error("Mismatch in prognostic_data_from_retrieved_fields_mask and input_fields_numpy")
LOGGER.error(
"prognostic_data_from_retrieved_fields_mask: %s shape=(variables)",
prognostic_data_from_retrieved_fields_mask.shape,
)
LOGGER.error("input_fields_numpy: %s shape=(dates, variables, grid)", input_fields_numpy.shape)
raise ValueError("Mismatch in prognostic_input_mask and input_fields_numpy")

input_tensor_numpy[:, prognostic_input_mask] = input_fields_numpy[:, prognostic_data_from_retrieved_fields_mask]

input_tensor_numpy[:, constant_from_input_mask] = input_fields_numpy[
Expand Down

0 comments on commit 7c0e816

Please sign in to comment.