diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index c44d7a2..9035c7d 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -11,7 +11,6 @@ import numpy as np import semantic_version -from anemoi.utils.text import dotted_line LOG = logging.getLogger(__name__) @@ -63,6 +62,11 @@ def from_metadata(cls, metadata): if "arguments" not in metadata["dataset"]: metadata["dataset"]["version"] = "0.1.0" + # When we changed from ecml_tools to anemoi-datasets, we went back in the + # versioning + if metadata["dataset"]["version"] in ("0.1.7", "0.1.8", "0.1.9"): + metadata["dataset"]["version"] = "0.2.0" + klass = from_versions(metadata["version"], metadata["dataset"]["version"]) return klass(metadata) @@ -294,7 +298,42 @@ def multi_step(self): @cached_property def imputable_variables(self): - return self._config_data.get("imputer", []) + result = [] + + def from_input_imputer(config): + for k, v in config.items(): + if not isinstance(v, list): + v = [v] + yield from v + + def from_constant_imputer(config): + yield from config.keys() + + def empty(config): + return [] + + IMPUTERS = { + "aifs.preprocessing.imputer.InputImputer": from_input_imputer, + "aifs.preprocessing.imputer.ConstantImputer": from_constant_imputer, + } + + for k, v in self._config_data.get("processors", {}).items(): + target = v.get("_target_") + if target is None: + continue + + if "imputer" in target.lower() and target not in IMPUTERS: + LOG.warning("Unknown imputer %s, ignoring", target) + continue + + source = IMPUTERS.get(target, empty) + result.extend(source(v.get("config", {}))) + + result = sorted(set(result)) + if result: + LOG.info("Imputable variables %s", result) + + return result def rounded_area(self, area): surface = (area[0] - area[2]) * (area[3] - area[1]) / 180 / 360 @@ -311,23 +350,3 @@ def report_loading_error(self): provenance_training = 
self._metadata["provenance_training"] LOG.error("Training provenance:\n%s", json.dumps(provenance_training, indent=2)) - - ########################################################################### - def describe(self): - print("num_input_features:", self.num_input_features) - print("hour_steps:", self.hour_steps) - result = list(range(0, self.multi_step)) - result = [-s * self.hour_steps for s in result] - print(sorted(result)) - print("multi_step:", self.multi_step) - print() - print("MARS requests:") - print(dotted_line()) - print("param_sfc:", self.param_sfc) - print("param_level_pl:", self.param_level_pl) - print("param_level_ml:", self.param_level_ml) - print("prognostic_params:", self.prognostic_params) - print("diagnostic_params:", self.diagnostic_params) - print("constants_from_input:", self.constants_from_input) - print("computed_constants:", self.computed_constants) - print("computed_forcings:", self.computed_forcings)