From a0c8f682889a126ad41cd5f90503885aa982e7d1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 28 May 2024 11:06:50 +0100 Subject: [PATCH] Work on cli --- pyproject.toml | 2 +- src/anemoi/inference/__main__.py | 59 +++----------------------------- src/anemoi/inference/runner.py | 43 ++++++++++++++++++++++- 3 files changed, 48 insertions(+), 56 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e468cf8..37f2d7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classifiers = [ dependencies = [ "tomli", # Only needed before 3.11 - "anemoi-utils>=0.2.0", + "anemoi-utils>=0.2.1", "semantic-version", "pyyaml", "torch", diff --git a/src/anemoi/inference/__main__.py b/src/anemoi/inference/__main__.py index 1dc285e..be940c2 100644 --- a/src/anemoi/inference/__main__.py +++ b/src/anemoi/inference/__main__.py @@ -8,69 +8,20 @@ # nor does it submit to any jurisdiction. # - -import argparse -import logging -import sys -import traceback +from anemoi.utils.cli import cli_main +from anemoi.utils.cli import make_parser from . import __version__ from .commands import COMMANDS -LOG = logging.getLogger(__name__) - +# For read-the-docs def create_parser(): - parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - - parser.add_argument( - "--version", - "-V", - action="store_true", - help="show the version and exit", - ) - parser.add_argument( - "--debug", - "-d", - action="store_true", - help="Debug mode", - ) - - subparsers = parser.add_subparsers(help="commands:", dest="command") - for name, command in COMMANDS.items(): - command_parser = subparsers.add_parser(name, help=command.__doc__) - command.add_arguments(command_parser) - - return parser + return make_parser(__doc__, COMMANDS) def main(): - parser = create_parser() - args = parser.parse_args() - - if args.version: - print(__version__) - return - - if args.command is None: - parser.print_help() - return - - cmd = COMMANDS[args.command] - - logging.basicConfig( - format="%(asctime)s %(levelname)s %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - level=logging.DEBUG if args.debug else logging.INFO, - ) - - try: - cmd.run(args) - except ValueError as e: - traceback.print_exc() - LOG.error("\nšŸ’£ %s", str(e).lstrip()) - LOG.error("šŸ’£ Exiting") - sys.exit(1) + cli_main(__version__, __doc__, COMMANDS) if __name__ == "__main__": diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index f3f339d..1e4a4c4 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -52,6 +52,7 @@ def ignore(*args, **kwargs): class Runner: + """_summary_""" def __init__(self, checkpoint): self.checkpoint = Checkpoint(checkpoint) @@ -68,7 +69,39 @@ def run( progress_callback=ignore, add_ensemble_dimension=False, ): - import torch + """_summary_ + + Parameters + ---------- + input_fields : _type_ + _description_ + lead_time : _type_ + _description_ + device : _type_ + _description_ + start_datetime : _type_, optional + _description_, by default None + output_callback : _type_, optional + _description_, by default ignore + autocast : _type_, optional + _description_, by default None + progress_callback : _type_, optional + _description_, by default ignore + add_ensemble_dimension : bool, optional + _description_, by default False + + Returns + ------- + _type_ + _description_ + + Raises + ------ + RuntimeError + _description_ + ValueError + _description_ + """ if autocast is None: autocast = self.checkpoint.precision @@ -374,4 +407,12 @@ def lagged(self): class DefaultRunner(Runner): + """_summary_ + + Parameters + ---------- + Runner : _type_ + _description_ + """ + pass