Skip to content

Commit

Permalink
Feature/Add method to validate environment in a checkpoint (#13)
Browse files Browse the repository at this point in the history
* Feature: Add environment checking
- Add property for `provenance_training` in Metadata
- Check environment on `validate_environment`

* chore: Update changelog

* Fix: Git record difference check

* Fix: Cull categories with no messages

* Add PR Tag

* Make link

* Add exempt packages from the checking

* Change return value

* Add Version class

* Address comments
- Improve error messaging
- Get inference environment from pkg
- Allow patch comparison

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove self

* Replace custom version with packaging

* Fix exemption

* Replace relative imports

* Add more exempt packages
- Protect packake_exemptions.py in CODEOWNERS

* Address naming issues
- Attempt an import to fix basic python modules being missed
- Add ignore

* Add __future__ annotations

* Add --validate to inspect subcommand

* Changelog

---------

Co-authored-by: Gert Mertes <[email protected]>
  • Loading branch information
HCookie and gmertes authored Sep 25, 2024
1 parent 378e5f1 commit 13ea14f
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
/.github/ @theissenhelen @jesperdramsch @gmertes
/.pre-commit-config.yaml @theissenhelen @jesperdramsch @gmertes
/pyproject.toml @theissenhelen @jesperdramsch @gmertes

# Protect package exemptions
/src/anemoi/inference/checkpoint/package_exemptions.py @gmertes @hcookie @theissenhelen @jesperdramsch
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Keep it human-readable, your future self will thank you!
### Added
- ci: changelog release updater
- earthkit-data replaces climetlab
- `validate_environment` on Checkpoint [#13](https://github.com/ecmwf/anemoi-inference/pull/13)
- Validate the environment against a checkpoint with `anemoi-inference inspect --validate path.ckpt`
- ci-hpc-config
- Add Condition to store data [#15](https://github.com/ecmwf/anemoi-inference/pull/15)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dependencies = [
"anytree",
"earthkit-data>=0.10",
"numpy",
"packaging",
"pyyaml",
"semantic-version",
"torch",
Expand Down
147 changes: 146 additions & 1 deletion src/anemoi/inference/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import json
import logging
import os
from functools import cached_property
from typing import Literal

from anemoi.utils.checkpoints import has_metadata
from anemoi.utils.checkpoints import load_metadata
from anemoi.utils.provenance import gather_provenance_info
from packaging.version import Version

from .metadata import Metadata
from anemoi.inference.checkpoint.metadata import Metadata
from anemoi.inference.checkpoint.package_exemptions import EXEMPT_NAMESPACES
from anemoi.inference.checkpoint.package_exemptions import EXEMPT_PACKAGES

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,3 +69,142 @@ def operational_config(self):

LOG.warning("No operational configuration found. Using default configuration.")
return {}

def validate_environment(
self,
all_packages: bool = False,
on_difference: Literal["warn", "error", "ignore"] = "warn",
*,
exempt_packages: list[str] | None = None,
) -> bool:
"""
Validate environment of the checkpoint against the current environment.
Parameters
----------
all_packages : bool, optional
Check all packages in environment or just `anemoi`'s, by default False
on_difference : Literal['warn', 'error', 'ignore'], optional
What to do on difference, by default "warn"
exempt_packages : list[str], optional
List of packages to exempt from the check, by default EXEMPT_PACKAGES
Returns
-------
bool
True if environment is valid, False otherwise
Raises
------
RuntimeError
If found difference and `on_difference` is 'error'
ValueError
If `on_difference` is not 'warn' or 'error'
"""
train_environment = self.provenance_training
inference_environment = gather_provenance_info(full=False)

# Override module information with more complete inference environment capture
import importlib.metadata as imp_metadata

module_versions = {
distribution.metadata["Name"].replace("-", "_"): distribution.metadata["Version"]
for distribution in imp_metadata.distributions()
}

inference_environment["module_versions"] = module_versions

exempt_packages = exempt_packages or []
exempt_packages.extend(EXEMPT_PACKAGES)

invalid_messages = {
"python": [],
"missing": [],
"mismatch": [],
"critical mismatch": [],
"uncommitted": [],
}

if train_environment["python"] != inference_environment["python"]:
invalid_messages["python"].append(
f"Python version mismatch: {train_environment['python']} != {inference_environment['python']}"
)

for module in train_environment["module_versions"].keys():
inference_module_name = module # Due to package name differences between retrieval methods this may change

if not all_packages and "anemoi" not in module:
continue
elif module in exempt_packages or module.split(".")[0] in EXEMPT_NAMESPACES:
continue
elif module.startswith("_"):
continue
elif module not in inference_environment["module_versions"]:
if "." in module and module.replace(".", "_") in inference_environment["module_versions"]:
inference_module_name = module.replace(".", "_")
else:
try:
import importlib

importlib.import_module(module)
continue
except (ModuleNotFoundError, ImportError):
pass
invalid_messages["missing"].append(f"Missing module in inference environment: {module}")
continue

train_environment_version = Version(train_environment["module_versions"][module])
inference_environment_version = Version(inference_environment["module_versions"][inference_module_name])

if train_environment_version < inference_environment_version:
invalid_messages["mismatch"].append(
f"Version of module {module} was lower in training then in inference: {train_environment_version!s} <= {inference_environment_version!s}"
)
elif train_environment_version > inference_environment_version:
invalid_messages["critical mismatch"].append(
f"CRITICAL: Version of module {module} was greater in training then in inference: {train_environment_version!s} > {inference_environment_version!s}"
)

for git_record in train_environment["git_versions"].keys():
file_record = train_environment["git_versions"][git_record]["git"]
if file_record["modified_files"] == 0 and file_record["untracked_files"] == 0:
continue

if git_record not in inference_environment["git_versions"]:
invalid_messages["uncommitted"].append(
f"Training environment contained uncommitted change missing in inference environment: {git_record}"
)
elif (
train_environment["git_versions"][git_record]["sha1"]
!= inference_environment["git_versions"][git_record]["sha1"]
):
invalid_messages["uncommitted"].append(
f"sha1 mismatch for git record between training and inference. {git_record} (training != inference): {train_environment['git_versions'][git_record]} != {inference_environment['git_versions'][git_record]}"
)

for git_record in inference_environment["git_versions"].keys():
file_record = inference_environment["git_versions"][git_record]["git"]
if file_record["modified_files"] == 0 and file_record["untracked_files"] == 0:
continue

if git_record not in train_environment["git_versions"]:
invalid_messages["uncommitted"].append(
f"Inference environment contains uncommited changes missing in training: {git_record}"
)

if len(invalid_messages) > 0:
text = "Environment validation failed. The following issues were found:\n" + "\n".join(
[f" {key}:\n " + "\n ".join(value) for key, value in invalid_messages.items() if len(value) > 0]
)
if on_difference == "warn":
LOG.warning(text)
elif on_difference == "error":
raise RuntimeError(text)
elif on_difference == "ignore":
pass
else:
raise ValueError(f"Invalid value for `on_difference`: {on_difference}")
return False

LOG.info(f"Environment validation passed")
return True
5 changes: 5 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def _config_training(self):
"""Part of the metadata refers to the model configuration"""
return self._metadata["config"]["training"]

@cached_property
def provenance_training(self):
"""Environmental Configuration when trained"""
return dict(self._metadata.get("provenance_training", {}))

###########################################################################
def _forcings(self, constants):
forcing = self._indices["data"]["input"]["forcing"]
Expand Down
15 changes: 15 additions & 0 deletions src/anemoi/inference/checkpoint/package_exemptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Complete package name to be exempt
EXEMPT_PACKAGES = [
"anemoi.training",
"hydra",
"hydra_plugins",
"lightning",
"pytorch_lightning",
"lightning_fabric",
"lightning_utilities",
]

# Entire namespaces to be exempt
EXEMPT_NAMESPACES = [
"hydra_plugins",
]
8 changes: 8 additions & 0 deletions src/anemoi/inference/commands/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,25 @@


class InspectCmd(Command):
"""Inspect the contents of a checkpoint file."""

need_logging = False

def add_arguments(self, command_parser):
command_parser.add_argument("path", help="Path to the checkpoint.")
command_parser.add_argument("--dump", action="store_true", help="Print internal information")
command_parser.add_argument(
"--validate", action="store_true", help="Validate the current virtual environment against the checkpoint"
)

def run(self, args):

c = Checkpoint(args.path)

if args.validate:
c.validate_environment()
return

if args.dump:
c.dump()
return
Expand Down

0 comments on commit 13ea14f

Please sign in to comment.