Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/diagnostics #18

Merged
merged 4 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Keep it human-readable, your future self will thank you!
- Add Condition to store data [#15](https://github.com/ecmwf/anemoi-inference/pull/15)

### Changed
- Fix: diagnostics bug when fields are non-accumulated, remove diagnostics from mars request [#18](https://github.com/ecmwf/anemoi-inference/pull/18)
- ci: updated workflows on PR and releases to use reusable actions

### Removed
Expand Down
56 changes: 50 additions & 6 deletions src/anemoi/inference/checkpoint/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# nor does it submit to any jurisdiction.


import json
import logging
from functools import cached_property

Expand Down Expand Up @@ -155,7 +156,12 @@ def select(self):
# order = self._dataset["order_by"]
return dict(
# valid_datetime="ascending",
param_level=self.variables,
param_level=sorted(
set(self.variables)
- set(self.computed_constants)
- set(self.computed_forcings)
- set(self.diagnostic_params)
),
# ensemble=self.checkpoint.ordering('ensemble'),
remapping={"param_level": "{param}_{levelist}"},
)
Expand Down Expand Up @@ -204,7 +210,7 @@ def _computed_constants(self):

LOG.debug("computed_constants data_mask: %s", data_mask)
LOG.debug("computed_constants model_mask: %s", model_mask)
LOG.info("Computed constants: %s", names)
LOG.debug("Computed constants: %s", names)

return data_mask, model_mask, names

Expand All @@ -230,8 +236,6 @@ def _computed_forcings(self):
]
)

print("FORCINGS", self._forcing_params())

constants = set(self._forcing_params()) - set(self.constants_from_input) - set(self.computed_constants)

if constants - known:
Expand All @@ -241,7 +245,7 @@ def _computed_forcings(self):

LOG.debug("computed_forcing data_mask: %s", data_mask)
LOG.debug("computed_forcing model_mask: %s", model_mask)
LOG.info("Computed forcings: %s", names)
# LOG.info("Computed forcings: %s", names)

return data_mask, model_mask, names

Expand All @@ -265,7 +269,7 @@ def _constants_from_input(self):

LOG.debug("constants_from_input: %s", data_mask)
LOG.debug("constants_from_input: %s", model_mask)
LOG.info("Constants from input: %s", names)
LOG.debug("Constants from input: %s", names)

return data_mask, model_mask, names

Expand Down Expand Up @@ -307,6 +311,11 @@ def diagnostic_params(self):
def prognostic_params(self):
return [self.index_to_variable[i] for i in self._indices["data"]["input"]["prognostic"]]

@cached_property
def accumulations_params(self):
# We assume that accumulations are the ones that are forecasts
return sorted(p[0] for p in self.param_step_sfc_pairs)

###########################################################################
@cached_property
def precision(self):
Expand Down Expand Up @@ -349,3 +358,38 @@ def predict_step_shape(self):
self.number_of_grid_points, # Grid points
self.num_input_features, # Fields
)

###########################################################################
def summary(self):

print(f"Prognostics: ({len(self.prognostic_params)})")
print(sorted(self.prognostic_params))
print()

print(f"Diagnostics: ({len(self.diagnostic_params)})")
print(sorted(self.diagnostic_params))
print()

print(f"Retrieved constants: ({len(self.constants_from_input)})")
print(sorted(self.constants_from_input))
print()

print(f"Computed constants: ({len(self.computed_constants)})")
print(sorted(self.computed_constants))
print()

print(f"Computed forcings: ({len(self.computed_forcings)})")
print(sorted(self.computed_forcings))
print()

print(f"Accumulations: ({len(self.accumulations_params)})")
print(sorted(self.accumulations_params))
print()

# print("Select:")
# print(json.dumps(self.select, indent=2))
# print()

# print("Order by:")
# print(json.dumps(self.order_by, indent=2))
# print()
6 changes: 3 additions & 3 deletions src/anemoi/inference/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ def _output(self, *args, **kwargs):

@property
def param_sfc(self):
return self.runner.checkpoint.param_sfc
return self.runner.param_sfc

@property
def param_level_pl(self):
return self.runner.checkpoint.param_level_pl
return self.runner.param_level_pl

@property
def param_level_ml(self):
return self.runner.checkpoint.param_level_ml
return self.runner.param_level_ml

@property
def constant_fields(self):
Expand Down
62 changes: 49 additions & 13 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def run(
_description_
"""

self.checkpoint.summary()

if autocast is None:
autocast = self.checkpoint.precision

Expand Down Expand Up @@ -310,7 +312,7 @@ def get_most_recent_datetime(input_fields):

most_recent_datetime = get_most_recent_datetime(input_fields)
reference_fields = [f for f in input_fields if f.datetime()["valid_time"] == most_recent_datetime]
precip_template = reference_fields[self.checkpoint.variable_to_index["lsm"]]
prognostic_template = reference_fields[self.checkpoint.variable_to_index["lsm"]]

accumulated_output = np.zeros(
shape=(len(diagnostic_output_mask), number_of_grid_points),
Expand All @@ -321,13 +323,14 @@ def get_most_recent_datetime(input_fields):
output_callback(
input_fields,
self.checkpoint.diagnostic_params,
precip_template,
prognostic_template,
accumulated_output[0].shape,
)
else:
output_callback(input_fields)

prognostic_params = self.checkpoint.prognostic_params
accumulations_params = self.checkpoint.accumulations_params

# with self.stepper(self.hour_steps) as stepper:

Expand Down Expand Up @@ -362,19 +365,28 @@ def get_most_recent_datetime(input_fields):
if len(diagnostic_output_mask):
for n, param in enumerate(self.checkpoint.diagnostic_params):
accumulated_output[n] += np.maximum(0, diagnostic_fields_numpy[:, n])
assert precip_template.datetime()["valid_time"] == most_recent_datetime, (
precip_template.datetime()["valid_time"],
assert prognostic_template.datetime()["valid_time"] == most_recent_datetime, (
prognostic_template.datetime()["valid_time"],
most_recent_datetime,
)
output_callback(
accumulated_output[n],
stepType="accum",
template=precip_template,
startStep=0,
endStep=step,
param=param,
check_nans=True, # param in can_be_missing,
)

if param in accumulations_params:
output_callback(
accumulated_output[n],
stepType="accum",
template=prognostic_template,
startStep=0,
endStep=step,
param=param,
check_nans=True, # param in can_be_missing,
)
else:
output_callback(
diagnostic_fields_numpy[:, n],
template=prognostic_template,
step=step,
check_nans=True, # param in can_be_missing,
)

# Next step

Expand Down Expand Up @@ -407,6 +419,30 @@ def lagged(self):
result = [-s * self.hour_steps for s in result]
return sorted(result)

@property
def param_sfc(self):
param_sfc = self.checkpoint.param_sfc

# Remove diagnostic params

param_sfc = [p for p in param_sfc if p not in self.checkpoint.diagnostic_params]

return param_sfc

@property
def param_level_pl(self):

# To do remove diagnostic params

return self.checkpoint.param_level_pl

@property
def param_level_ml(self):

# To do remove diagnostic params

return self.checkpoint.param_level_ml


class DefaultRunner(Runner):
"""_summary_
Expand Down
Loading