Skip to content

Commit

Permalink
Work on cli
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 28, 2024
1 parent be95203 commit a0c8f68
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 56 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ classifiers = [

dependencies = [
"tomli", # Only needed before 3.11
"anemoi-utils>=0.2.0",
"anemoi-utils>=0.2.1",
"semantic-version",
"pyyaml",
"torch",
Expand Down
59 changes: 5 additions & 54 deletions src/anemoi/inference/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,20 @@
# nor does it submit to any jurisdiction.
#


import argparse
import logging
import sys
import traceback
from anemoi.utils.cli import cli_main
from anemoi.utils.cli import make_parser

from . import __version__
from .commands import COMMANDS

LOG = logging.getLogger(__name__)


# For read-the-docs
def create_parser():
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)

parser.add_argument(
"--version",
"-V",
action="store_true",
help="show the version and exit",
)
parser.add_argument(
"--debug",
"-d",
action="store_true",
help="Debug mode",
)

subparsers = parser.add_subparsers(help="commands:", dest="command")
for name, command in COMMANDS.items():
command_parser = subparsers.add_parser(name, help=command.__doc__)
command.add_arguments(command_parser)

return parser
return make_parser(__doc__, COMMANDS)


def main():
parser = create_parser()
args = parser.parse_args()

if args.version:
print(__version__)
return

if args.command is None:
parser.print_help()
return

cmd = COMMANDS[args.command]

logging.basicConfig(
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.DEBUG if args.debug else logging.INFO,
)

try:
cmd.run(args)
except ValueError as e:
traceback.print_exc()
LOG.error("\n💣 %s", str(e).lstrip())
LOG.error("💣 Exiting")
sys.exit(1)
cli_main(__version__, __doc__, COMMANDS)


if __name__ == "__main__":
Expand Down
43 changes: 42 additions & 1 deletion src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def ignore(*args, **kwargs):


class Runner:
"""_summary_"""

def __init__(self, checkpoint):
self.checkpoint = Checkpoint(checkpoint)
Expand All @@ -68,7 +69,39 @@ def run(
progress_callback=ignore,
add_ensemble_dimension=False,
):
import torch
"""_summary_
Parameters
----------
input_fields : _type_
_description_
lead_time : _type_
_description_
device : _type_
_description_
start_datetime : _type_, optional
_description_, by default None
output_callback : _type_, optional
_description_, by default ignore
autocast : _type_, optional
_description_, by default None
progress_callback : _type_, optional
_description_, by default ignore
add_ensemble_dimension : bool, optional
_description_, by default False
Returns
-------
_type_
_description_
Raises
------
RuntimeError
_description_
ValueError
_description_
"""

if autocast is None:
autocast = self.checkpoint.precision
Expand Down Expand Up @@ -374,4 +407,12 @@ def lagged(self):


class DefaultRunner(Runner):
"""_summary_
Parameters
----------
Runner : _type_
_description_
"""

pass

0 comments on commit a0c8f68

Please sign in to comment.