Skip to content

Commit

Permalink
support emputers
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 23, 2024
1 parent 2b17ae4 commit 4ab2d62
Showing 1 changed file with 41 additions and 22 deletions.
63 changes: 41 additions & 22 deletions src/anemoi/inference/checkpoint/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import numpy as np
import semantic_version
from anemoi.utils.text import dotted_line

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,6 +62,11 @@ def from_metadata(cls, metadata):
if "arguments" not in metadata["dataset"]:
metadata["dataset"]["version"] = "0.1.0"

# When we changed from ecml_tools to anemoi-datasets, we went back in the
# versionning
if metadata["dataset"]["version"] in ("0.1.7", "0.1.8", "0.1.9"):
metadata["dataset"]["version"] = "0.2.0"

klass = from_versions(metadata["version"], metadata["dataset"]["version"])
return klass(metadata)

Expand Down Expand Up @@ -294,7 +298,42 @@ def multi_step(self):

@cached_property
def imputable_variables(self):
return self._config_data.get("imputer", [])
result = []

def from_input_imputer(config):
for k, v in config.items():
if not isinstance(v, list):
v = [v]
yield from v

def from_constant_imputer(config):
yield from config.keys()

def empty(config):
return []

IMPUTERS = {
"aifs.preprocessing.imputer.InputImputer": from_input_imputer,
"aifs.preprocessing.imputer.ConstantImputer": from_constant_imputer,
}

for k, v in self._config_data.get("processors", {}).items():
target = v.get("_target_")
if target is None:
continue

if "imputer" in target.lower() and target not in IMPUTERS:
LOG.warning("Unknown imputer %s, ignoring", target)
continue

source = IMPUTERS.get(target, empty)
result.extend(source(v.get("config", {})))

result = sorted(set(result))
if result:
LOG.info("Imputable variables %s", result)

return result

def rounded_area(self, area):
surface = (area[0] - area[2]) * (area[3] - area[1]) / 180 / 360
Expand All @@ -311,23 +350,3 @@ def report_loading_error(self):
provenance_training = self._metadata["provenance_training"]

LOG.error("Training provenance:\n%s", json.dumps(provenance_training, indent=2))

###########################################################################
def describe(self):
print("num_input_features:", self.num_input_features)
print("hour_steps:", self.hour_steps)
result = list(range(0, self.multi_step))
result = [-s * self.hour_steps for s in result]
print(sorted(result))
print("multi_step:", self.multi_step)
print()
print("MARS requests:")
print(dotted_line())
print("param_sfc:", self.param_sfc)
print("param_level_pl:", self.param_level_pl)
print("param_level_ml:", self.param_level_ml)
print("prognostic_params:", self.prognostic_params)
print("diagnostic_params:", self.diagnostic_params)
print("constants_from_input:", self.constants_from_input)
print("computed_constants:", self.computed_constants)
print("computed_forcings:", self.computed_forcings)

0 comments on commit 4ab2d62

Please sign in to comment.