Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide options to patch the checkpoint metadata when running a model #73

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,26 @@
class Checkpoint:
"""Represents an inference checkpoint."""

def __init__(self, path):
def __init__(self, path, *, patch_metadata=None):
self.path = path
self.patch_metadata = patch_metadata

def __repr__(self):
return f"Checkpoint({self.path})"

@cached_property
def _metadata(self):
try:
return Metadata(*load_metadata(self.path, supporting_arrays=True))
result = Metadata(*load_metadata(self.path, supporting_arrays=True))
except Exception as e:
LOG.warning("Version for not support `supporting_arrays` (%s)", e)
return Metadata(load_metadata(self.path))
result = Metadata(load_metadata(self.path))

if self.patch_metadata:
LOG.warning("Patching metadata with %r", self.patch_metadata)
result.patch(self.patch_metadata)

return result

###########################################################################
# Forwards used by the runner
Expand Down
43 changes: 34 additions & 9 deletions src/anemoi/inference/commands/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ def add_arguments(self, command_parser):
),
)

group.add_argument(
"--view",
action="store_true",
help=(
"View the metadata in place, using the specified pager."
" See the ``--pager`` argument for more information."
),
)

group.add_argument(
"--remove",
action="store_true",
Expand Down Expand Up @@ -91,22 +100,31 @@ def add_arguments(self, command_parser):
default=os.environ.get("EDITOR", "vi"),
)

command_parser.add_argument(
"--pager",
help="Editor to use for the ``--view`` option. Default to ``$PAGER`` if defined, else ``less``.",
default=os.environ.get("PAGER", "less"),
b8raoult marked this conversation as resolved.
Show resolved Hide resolved
)

command_parser.add_argument(
"--json",
action="store_true",
help="Use the JSON format with ``--dump`` and ``--edit``.",
help="Use the JSON format with ``--dump``, ``--view`` and ``--edit``.",
)

command_parser.add_argument(
"--yaml",
action="store_true",
help="Use the YAML format with ``--dump`` and ``--edit``.",
help="Use the YAML format with ``--dump``, ``--view`` and ``--edit``.",
)

def run(self, args):
if args.edit:
return self.edit(args)

if args.view:
return self.view(args)

if args.remove:
return self.remove(args)

Expand All @@ -120,6 +138,12 @@ def run(self, args):
return self.supporting_arrays(args)

def edit(self, args):
return self._edit(args, view=False, cmd=args.editor)

def view(self, args):
return self._edit(args, view=True, cmd=args.pager)

def _edit(self, args, view, cmd):

from anemoi.utils.checkpoints import load_metadata
from anemoi.utils.checkpoints import replace_metadata
Expand All @@ -143,15 +167,16 @@ def edit(self, args):
with open(path, "w") as f:
dump(metadata, f, **kwargs)

subprocess.check_call([args.editor, path])
subprocess.check_call([cmd, path])

with open(path) as f:
edited = load(f)
if not view:
with open(path) as f:
edited = load(f)

if edited != metadata:
replace_metadata(args.path, edited)
else:
LOG.info("No changes made.")
if edited != metadata:
replace_metadata(args.path, edited)
else:
LOG.info("No changes made.")

def remove(self, args):
from anemoi.utils.checkpoints import remove_metadata
Expand Down
4 changes: 4 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class Config:
such as `eccodes`. In certain cases, the variables mey be set too late, if the package for which they are intended
is already loaded when the runner is configured."""

patch_metadata: dict = {}
"""A dictionary of metadata to patch the checkpoint metadata with. This is used to test new features or to work around
issues with the checkpoint metadata."""

development_hacks: dict = {}
"""A dictionary of development hacks to apply to the runner. This is used to test new features or to work around"""

Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/inference/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,6 @@ def _legacy_data_request(self):
if len(checks[c]) > 1:
warnings.warn(f"{c} is ambigous: {checks[c]}")

result = [r for r in result if r["grid"] is not None]

return result[0]
17 changes: 17 additions & 0 deletions src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,23 @@ def print_variable_categories(self):
for name, categories in sorted(self.variable_categories().items()):
LOG.info(f" {name:{length}} => {', '.join(categories)}")

###########################################################################

def patch(self, patch):
"""Patch the metadata with the given patch"""

def merge(main, patch):

for k, v in patch.items():
if isinstance(v, dict):
if k not in main:
main[k] = {}
merge(main[k], v)
else:
main[k] = v

merge(self._metadata, patch)


class SourceMetadata(Metadata):
"""An object that holds metadata of a source. It is only the `dataset` and `supporting_arrays` parts of the metadata.
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@


class ArchiveCollector:
"""Collects archive requests"""

UNIQUE = {"date", "hdate", "time", "referenceDate", "type", "stream", "expver"}

def __init__(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def __init__(
use_grib_paramid=False,
verbosity=0,
inference_options=None,
patch_metadata={},
development_hacks={}, # For testing purposes, don't use in production
):
self._checkpoint = Checkpoint(checkpoint)
self._checkpoint = Checkpoint(checkpoint, patch_metadata=patch_metadata)

self.device = device
self.precision = precision
Expand Down
1 change: 1 addition & 0 deletions src/anemoi/inference/runners/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, config):
verbosity=config.verbosity,
report_error=config.report_error,
use_grib_paramid=config.use_grib_paramid,
patch_metadata=config.patch_metadata,
development_hacks=config.development_hacks,
)

Expand Down
Loading