From ec969b66a06e3e5c2f8c041108c8c42c4d39e0fd Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 12 Mar 2024 14:11:17 +0000 Subject: [PATCH] Add RemoteModel --- ai_models/__main__.py | 11 ++++++----- ai_models/model.py | 18 ------------------ ai_models/remote.py | 38 ++++++++++++++++++++++++++++++++++---- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/ai_models/__main__.py b/ai_models/__main__.py index 148910f..caad1df 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -271,7 +271,11 @@ def _main(argv): def run(cfg: dict, model_args: list): - model = load_model(cfg["model"], **cfg, model_args=model_args) + if cfg["remote_execution"]: + from .remote import RemoteModel + model = RemoteModel(**cfg, model_args=model_args) + else: + model = load_model(cfg["model"], **cfg, model_args=model_args) if cfg["fields"]: model.print_fields() @@ -289,10 +293,7 @@ def run(cfg: dict, model_args: list): sys.exit(0) try: - if cfg["remote_execution"]: - model.remote(cfg, model_args) - else: - model.run() + model.run() except FileNotFoundError as e: LOG.exception(e) LOG.error( diff --git a/ai_models/model.py b/ai_models/model.py index 3f81be9..636abe8 100644 --- a/ai_models/model.py +++ b/ai_models/model.py @@ -24,7 +24,6 @@ from .checkpoint import peek from .inputs import get_input from .outputs import get_output -from .remote import RemoteClient from .stepper import Stepper LOG = logging.getLogger(__name__) @@ -458,23 +457,6 @@ def write_input_fields( check=True, ) - def remote(self, cfg: dict, model_args: list): - with tempfile.TemporaryDirectory() as tmpdirname: - input_file = os.path.join(tmpdirname, "input.grib") - output_file = os.path.join(tmpdirname, "output.grib") - self.all_fields.save(input_file) - - client = RemoteClient( - input_file=input_file, - output_file=output_file, - ) - - client.run(cfg, model_args) - - ds = cml.load_source("file", output_file) - for field in ds: - self.write(None, template=field) - def load_model(name, **kwargs): return available_models()[name].load()(**kwargs) diff --git a/ai_models/remote.py b/ai_models/remote.py index e047bd0..55af89d 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -1,15 +1,48 @@ import logging import os import sys +import tempfile import time +from functools import cached_property from urllib.parse import urljoin +import climetlab as cml import requests from multiurl import download, robust +from .model import Model + LOG = logging.getLogger(__name__) +class RemoteModel(Model): + def __init__(self, **kwargs): + kwargs["download_assets"] = False + + super().__init__(**kwargs) + + self.cfg = kwargs + self.client = RemoteClient() + + def run(self): + with tempfile.TemporaryDirectory() as tmpdirname: + input_file = os.path.join(tmpdirname, "input.grib") + output_file = os.path.join(tmpdirname, "output.grib") + self.all_fields.save(input_file) + + self.client.input_file = input_file + self.client.output_file = output_file + + self.client.run(self.cfg) + + ds = cml.load_source("file", output_file) + for field in ds: + self.write(None, template=field) + + def parse_model_args(self, args): + return None + + class BearerAuth(requests.auth.AuthBase): def __init__(self, token): self.token = token @@ -61,10 +94,7 @@ def __init__( self.input_file = input_file self._timeout = 300 - def run(self, cfg: dict, model_args: list): - cfg.pop("remote_execution", None) - cfg["model_args"] = model_args - + def run(self, cfg: dict): # upload file with open(self.input_file, "rb") as f: LOG.info("Uploading input file to remote")