From c6a2b80453315869fb862f254a3700f2b9dcd5f0 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 26 Mar 2024 14:10:14 +0000 Subject: [PATCH] Refactor remote.py to remote subpackage --- ai_models/remote/__init__.py | 4 + ai_models/{remote.py => remote/api.py} | 99 ------------------------ ai_models/remote/model.py | 102 +++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 99 deletions(-) create mode 100644 ai_models/remote/__init__.py rename ai_models/{remote.py => remote/api.py} (67%) create mode 100644 ai_models/remote/model.py diff --git a/ai_models/remote/__init__.py b/ai_models/remote/__init__.py new file mode 100644 index 0000000..0fc07af --- /dev/null +++ b/ai_models/remote/__init__.py @@ -0,0 +1,4 @@ +from .api import RemoteAPI +from .model import RemoteModel + +_all__ = ["RemoteAPI", "RemoteModel"] diff --git a/ai_models/remote.py b/ai_models/remote/api.py similarity index 67% rename from ai_models/remote.py rename to ai_models/remote/api.py index 48493d0..d14412a 100644 --- a/ai_models/remote.py +++ b/ai_models/remote/api.py @@ -1,115 +1,16 @@ import logging import os import sys -import tempfile import time -from functools import cached_property from urllib.parse import urljoin -import climetlab as cml import requests from multiurl import download, robust from tqdm import tqdm -from .model import Model - LOG = logging.getLogger(__name__) -class RemoteModel(Model): - def __init__(self, **kwargs): - self.cfg = kwargs - self.cfg["download_assets"] = False - self.cfg["assets_extra_dir"] = None - - self.model = self.cfg["model"] - self._param = {} - self.api = RemoteAPI() - - self.load_parameters() - - super().__init__(**self.cfg) - - def __getattr__(self, name): - return self.get_parameter(name) - - def run(self): - with tempfile.TemporaryDirectory() as tmpdirname: - input_file = os.path.join(tmpdirname, "input.grib") - output_file = os.path.join(tmpdirname, "output.grib") - self.all_fields.save(input_file) - - self.api.input_file = input_file - self.api.output_file = output_file - - self.api.run(self.cfg) - - ds = cml.load_source("file", output_file) - for field in ds: - self.write(None, template=field) - - def parse_model_args(self, args): - return None - - def patch_retrieve_request(self, request): - patched = self.api.patch_retrieve_request(self.cfg, request) - request.update(patched) - - def load_parameters(self): - params = self.api.metadata( - self.model, - [ - "expver", - "version", - "grid", - "area", - "param_level_ml", - "param_level_pl", - "param_sfc", - "lagged", - "grib_extra_metadata", - "retrieve", - ], - ) - self._param.update(params) - - def get_parameter(self, name): - if (param := self._param.get(name, None)) is not None: - return param - - self._param.update(self.api.metadata(self.model, name)) - - return self._param[name] - - @cached_property - def param_level_ml(self): - return self.get_parameter("param_level_ml") or ([], []) - - @cached_property - def param_level_pl(self): - return self.get_parameter("param_level_pl") or ([], []) - - @cached_property - def param_sfc(self): - return self.get_parameter("param_sfc") or [] - - @cached_property - def lagged(self): - return self.get_parameter("lagged") or False - - @cached_property - def version(self): - return self.get_parameter("version") or 1 - - @cached_property - def grib_extra_metadata(self): - return self.get_parameter("grib_extra_metadata") or {} - - @cached_property - def retrieve(self): - return self.get_parameter("retrieve") or {} - - class BearerAuth(requests.auth.AuthBase): def __init__(self, token): self.token = token diff --git a/ai_models/remote/model.py b/ai_models/remote/model.py new file mode 100644 index 0000000..7fdd8ee --- /dev/null +++ b/ai_models/remote/model.py @@ -0,0 +1,102 @@ +import os +import tempfile +from functools import cached_property + +import climetlab as cml + +from ..model import Model +from .api import RemoteAPI + + +class RemoteModel(Model): + def __init__(self, **kwargs): + self.cfg = kwargs + self.cfg["download_assets"] = False + self.cfg["assets_extra_dir"] = None + + self.model = self.cfg["model"] + self._param = {} + self.api = RemoteAPI() + + self.load_parameters() + + super().__init__(**self.cfg) + + def __getattr__(self, name): + return self.get_parameter(name) + + def run(self): + with tempfile.TemporaryDirectory() as tmpdirname: + input_file = os.path.join(tmpdirname, "input.grib") + output_file = os.path.join(tmpdirname, "output.grib") + self.all_fields.save(input_file) + + self.api.input_file = input_file + self.api.output_file = output_file + + self.api.run(self.cfg) + + ds = cml.load_source("file", output_file) + for field in ds: + self.write(None, template=field) + + def parse_model_args(self, args): + return None + + def patch_retrieve_request(self, request): + patched = self.api.patch_retrieve_request(self.cfg, request) + request.update(patched) + + def load_parameters(self): + params = self.api.metadata( + self.model, + [ + "expver", + "version", + "grid", + "area", + "param_level_ml", + "param_level_pl", + "param_sfc", + "lagged", + "grib_extra_metadata", + "retrieve", + ], + ) + self._param.update(params) + + def get_parameter(self, name): + if (param := self._param.get(name, None)) is not None: + return param + + self._param.update(self.api.metadata(self.model, name)) + + return self._param[name] + + @cached_property + def param_level_ml(self): + return self.get_parameter("param_level_ml") or ([], []) + + @cached_property + def param_level_pl(self): + return self.get_parameter("param_level_pl") or ([], []) + + @cached_property + def param_sfc(self): + return self.get_parameter("param_sfc") or [] + + @cached_property + def lagged(self): + return self.get_parameter("lagged") or False + + @cached_property + def version(self): + return self.get_parameter("version") or 1 + + @cached_property + def grib_extra_metadata(self): + return self.get_parameter("grib_extra_metadata") or {} + + @cached_property + def retrieve(self): + return self.get_parameter("retrieve") or {}