Skip to content

Commit

Permalink
Add RemoteModel
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Mar 12, 2024
1 parent 08ef357 commit ec969b6
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 27 deletions.
11 changes: 6 additions & 5 deletions ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,11 @@ def _main(argv):


def run(cfg: dict, model_args: list):
model = load_model(cfg["model"], **cfg, model_args=model_args)
if cfg["remote_execution"]:
from .remote import RemoteModel
model = RemoteModel(**cfg, model_args=model_args)
else:
model = load_model(cfg["model"], **cfg, model_args=model_args)

if cfg["fields"]:
model.print_fields()
Expand All @@ -289,10 +293,7 @@ def run(cfg: dict, model_args: list):
sys.exit(0)

try:
if cfg["remote_execution"]:
model.remote(cfg, model_args)
else:
model.run()
model.run()
except FileNotFoundError as e:
LOG.exception(e)
LOG.error(
Expand Down
18 changes: 0 additions & 18 deletions ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from .checkpoint import peek
from .inputs import get_input
from .outputs import get_output
from .remote import RemoteClient
from .stepper import Stepper

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -458,23 +457,6 @@ def write_input_fields(
check=True,
)

def remote(self, cfg: dict, model_args: list):
with tempfile.TemporaryDirectory() as tmpdirname:
input_file = os.path.join(tmpdirname, "input.grib")
output_file = os.path.join(tmpdirname, "output.grib")
self.all_fields.save(input_file)

client = RemoteClient(
input_file=input_file,
output_file=output_file,
)

client.run(cfg, model_args)

ds = cml.load_source("file", output_file)
for field in ds:
self.write(None, template=field)


def load_model(name, **kwargs):
return available_models()[name].load()(**kwargs)
Expand Down
38 changes: 34 additions & 4 deletions ai_models/remote.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,48 @@
import logging
import os
import sys
import tempfile
import time
from functools import cached_property
from urllib.parse import urljoin

import climetlab as cml
import requests
from multiurl import download, robust

from .model import Model

LOG = logging.getLogger(__name__)


class RemoteModel(Model):
def __init__(self, **kwargs):
kwargs["download_assets"] = False

super().__init__(**kwargs)

self.cfg = kwargs
self.client = RemoteClient()

def run(self):
with tempfile.TemporaryDirectory() as tmpdirname:
input_file = os.path.join(tmpdirname, "input.grib")
output_file = os.path.join(tmpdirname, "output.grib")
self.all_fields.save(input_file)

self.client.input_file = input_file
self.client.output_file = output_file

self.client.run(self.cfg)

ds = cml.load_source("file", output_file)
for field in ds:
self.write(None, template=field)

def parse_model_args(self, args):
return None


class BearerAuth(requests.auth.AuthBase):
def __init__(self, token):
self.token = token
Expand Down Expand Up @@ -61,10 +94,7 @@ def __init__(
self.input_file = input_file
self._timeout = 300

def run(self, cfg: dict, model_args: list):
cfg.pop("remote_execution", None)
cfg["model_args"] = model_args

def run(self, cfg: dict):
# upload file
with open(self.input_file, "rb") as f:
LOG.info("Uploading input file to remote")
Expand Down

0 comments on commit ec969b6

Please sign in to comment.