From be952033ce52da460d306e71fcad6f857cd6dda4 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 28 May 2024 10:32:28 +0100 Subject: [PATCH] update --- pyproject.toml | 1 + src/anemoi/inference/__main__.py | 7 ++++++- src/anemoi/inference/runner.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 53a1599..e468cf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "anemoi-utils>=0.2.0", "semantic-version", "pyyaml", + "torch", ] [project.optional-dependencies] diff --git a/src/anemoi/inference/__main__.py b/src/anemoi/inference/__main__.py index 2ac1a15..1dc285e 100644 --- a/src/anemoi/inference/__main__.py +++ b/src/anemoi/inference/__main__.py @@ -20,7 +20,7 @@ LOG = logging.getLogger(__name__) -def main(): +def create_parser(): parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument( @@ -41,6 +41,11 @@ def main(): command_parser = subparsers.add_parser(name, help=command.__doc__) command.add_arguments(command_parser) + return parser + + +def main(): + parser = create_parser() args = parser.parse_args() if args.version: diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 34044ae..f3f339d 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -68,6 +68,7 @@ def run( progress_callback=ignore, add_ensemble_dimension=False, ): + import torch if autocast is None: autocast = self.checkpoint.precision