Skip to content

Commit

Permalink
Change outputs to plugin format
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Mar 25, 2024
1 parent 2d8f79b commit 961807f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _main(argv):
"--output",
default="file",
help="Where to output the results",
choices=available_outputs(),
choices=sorted(available_outputs()),
)

parser.add_argument(
Expand Down
14 changes: 6 additions & 8 deletions ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import cached_property

import climetlab as cml
import entrypoints
import numpy as np

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -117,18 +118,15 @@ def write(self, *args, **kwargs):
pass


OUTPUTS = dict(
file=FileOutput,
none=NoneOutput,
)


def get_output(name, owner, *args, **kwargs):
result = OUTPUTS[name](owner, *args, **kwargs)
result = available_outputs()[name].load()(owner, *args, **kwargs)
if kwargs.get("hindcast_reference_year") is not None:
result = HindcastReLabel(owner, result, **kwargs)
return result


def available_outputs():
return sorted(OUTPUTS.keys())
result = {}
for e in entrypoints.get_group_all("ai_models.output"):
result[e.name] = e
return result
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def read(fname):
"cds=ai_models.inputs:CdsInput",
"opendata=ai_models.inputs:OpenDataInput",
],
"ai_models.output": [
"file=ai_models.outputs:FileOutput",
"none=ai_models.outputs:NoneOutput",
],
},
classifiers=[
"Development Status :: 3 - Alpha",
Expand Down

0 comments on commit 961807f

Please sign in to comment.