diff --git a/ai_models/__main__.py b/ai_models/__main__.py index caad1df..e86c722 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -273,6 +273,7 @@ def _main(argv): def run(cfg: dict, model_args: list): if cfg["remote_execution"]: from .remote import RemoteModel + model = RemoteModel(**cfg, model_args=model_args) else: model = load_model(cfg["model"], **cfg, model_args=model_args) diff --git a/ai_models/remote.py b/ai_models/remote.py index 55af89d..1e54c5c 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -3,7 +3,7 @@ import sys import tempfile import time -from functools import cached_property +from functools import cache, cached_property from urllib.parse import urljoin import climetlab as cml @@ -17,12 +17,13 @@ class RemoteModel(Model): def __init__(self, **kwargs): - kwargs["download_assets"] = False - - super().__init__(**kwargs) - self.cfg = kwargs - self.client = RemoteClient() + self.cfg["download_assets"] = False + self.cfg["assets_extra_dir"] = None + self._param = {} + self.api = RemoteClient() + + super().__init__(**self.cfg) def run(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -30,10 +31,10 @@ def run(self): output_file = os.path.join(tmpdirname, "output.grib") self.all_fields.save(input_file) - self.client.input_file = input_file - self.client.output_file = output_file + self.api.input_file = input_file + self.api.output_file = output_file - self.client.run(self.cfg) + self.api.run(self.cfg) ds = cml.load_source("file", output_file) for field in ds: @@ -42,6 +43,41 @@ def run(self): def parse_model_args(self, args): return None + def __getattr__(self, name): + return self.get_param(name) + + @cache + def get_param(self, name): + return self.api.get_param(self.cfg["model"], name).get(name, None) + + @cached_property + def param_level_ml(self): + return self.get_param("param_level_ml") or ([], []) + + @cached_property + def param_level_pl(self): + return self.get_param("param_level_pl") or ([], []) + + @cached_property + def param_sfc(self): + return self.get_param("param_sfc") or [] + + @cached_property + def lagged(self): + return self.get_param("lagged") or False + + @cached_property + def version(self): + return self.get_param("version") or 1 + + @cached_property + def grib_extra_metadata(self): + return self.get_param("grib_extra_metadata") or {} + + @cached_property + def retrieve(self): + return self.get_param("retrieve") or {} + class BearerAuth(requests.auth.AuthBase): def __init__(self, token): @@ -133,17 +169,29 @@ def run(self, cfg: dict): LOG.debug("Result written to %s", self.output_file) - def _request(self, type, href, data=None, json=None, auth=None): - r = robust(type, retry_after=self._timeout)( + def get_param(self, model, param): + if isinstance(param, str): + return self._request( + requests.get, f"metadata/{model}/{param}", with_status=False + ) + else: + return self._request( + requests.post, f"metadata/{model}", json=param, with_status=False + ) + + def _request(self, type, href, data=None, json=None, auth=None, with_status=True): + r = robust(type, retry_after=30)( urljoin(self.url, href), json=json, data=data, auth=self.auth, timeout=self._timeout, ) - - status, href = self._update_state(r) - return status, href + if with_status: + status, href = self._update_state(r) + return status, href + else: + return r.json() def _update_state(self, response: requests.Response): if response.status_code == 401: