From 21cf2f007ac22902231b018192b67257e398e244 Mon Sep 17 00:00:00 2001 From: Robin Andersson Date: Tue, 11 Jun 2024 15:45:03 +0200 Subject: [PATCH] [HWORKS-1048] Support multiple modularized project environments --- python/hsml/deployment.py | 9 +++++++ python/hsml/engine/model_engine.py | 9 ++----- python/hsml/model.py | 3 +++ python/hsml/model_serving.py | 5 ++-- python/hsml/predictor.py | 16 +++++++++++ python/tests/fixtures/predictor_fixtures.json | 27 +++++++++++++++++++ python/tests/test_model.py | 3 ++- 7 files changed, 62 insertions(+), 10 deletions(-) diff --git a/python/hsml/deployment.py b/python/hsml/deployment.py index 473c3b9c..0336d120 100644 --- a/python/hsml/deployment.py +++ b/python/hsml/deployment.py @@ -461,6 +461,15 @@ def api_protocol(self): def api_protocol(self, api_protocol: str): self._predictor.api_protocol = api_protocol + @property + def environment(self): + """Name of inference environment""" + return self._predictor.environment + + @environment.setter + def environment(self, environment: str): + self._predictor.environment = environment + def __repr__(self): desc = ( f", description: {self._description!r}" diff --git a/python/hsml/engine/model_engine.py b/python/hsml/engine/model_engine.py index 29acd269..b4b4090a 100644 --- a/python/hsml/engine/model_engine.py +++ b/python/hsml/engine/model_engine.py @@ -14,7 +14,6 @@ # limitations under the License. # -import importlib import json import os import tempfile @@ -24,7 +23,7 @@ from hsml import client, constants, util from hsml.client.exceptions import ModelRegistryException, RestAPIError from hsml.core import dataset_api, model_api -from hsml.engine import hopsworks_engine, local_engine +from hsml.engine import local_engine from tqdm.auto import tqdm @@ -33,11 +32,7 @@ def __init__(self): self._model_api = model_api.ModelApi() self._dataset_api = dataset_api.DatasetApi() - pydoop_spec = importlib.util.find_spec("pydoop") - if pydoop_spec is None: - self._engine = local_engine.LocalEngine() - else: - self._engine = hopsworks_engine.HopsworksEngine() + self._engine = local_engine.LocalEngine() def _poll_model_available(self, model_instance, await_registration): if await_registration > 0: diff --git a/python/hsml/model.py b/python/hsml/model.py index 2d63a7ee..e6147d5f 100644 --- a/python/hsml/model.py +++ b/python/hsml/model.py @@ -173,6 +173,7 @@ def deploy( inference_batcher: Optional[Union[InferenceBatcher, dict]] = None, transformer: Optional[Union[Transformer, dict]] = None, api_protocol: Optional[str] = IE.API_PROTOCOL_REST, + environment: Optional[str] = None, ): """Deploy the model. @@ -203,6 +204,7 @@ def deploy( inference_batcher: Inference batcher configuration. transformer: Transformer to be deployed together with the predictor. api_protocol: API protocol to be enabled in the deployment (i.e., 'REST' or 'GRPC'). Defaults to 'REST'. + environment: The inference environment to use. # Returns `Deployment`: The deployment metadata object of a new or existing deployment. @@ -223,6 +225,7 @@ def deploy( inference_batcher=inference_batcher, transformer=transformer, api_protocol=api_protocol, + environment=environment, ) return predictor.deploy() diff --git a/python/hsml/model_serving.py b/python/hsml/model_serving.py index a256fdc1..21d04b83 100644 --- a/python/hsml/model_serving.py +++ b/python/hsml/model_serving.py @@ -285,7 +285,7 @@ def postprocess(self, outputs): return Transformer(script_file=script_file, resources=resources) - def create_deployment(self, predictor: Predictor, name: Optional[str] = None): + def create_deployment(self, predictor: Predictor, name: Optional[str] = None, environment: Optional[str] = None): """Create a Deployment metadata object. !!! example @@ -348,12 +348,13 @@ def create_deployment(self, predictor: Predictor, name: Optional[str] = None): # Arguments predictor: predictor to be used in the deployment name: name of the deployment + environment: The inference environment to use # Returns `Deployment`. The model metadata object. """ - return Deployment(predictor=predictor, name=name) + return Deployment(predictor=predictor, name=name, environment=environment) @property def project_name(self): diff --git a/python/hsml/predictor.py b/python/hsml/predictor.py index 10cc29f4..bf7d7c70 100644 --- a/python/hsml/predictor.py +++ b/python/hsml/predictor.py @@ -56,6 +56,7 @@ def __init__( created_at: Optional[str] = None, creator: Optional[str] = None, api_protocol: Optional[str] = INFERENCE_ENDPOINTS.API_PROTOCOL_REST, + environment: Optional[str] = None, **kwargs, ): serving_tool = ( @@ -91,6 +92,7 @@ def __init__( self._transformer = util.get_obj_from_json(transformer, Transformer) self._validate_script_file(self._model_framework, self._script_file) self._api_protocol = api_protocol + self._environment = environment def deploy(self): """Create a deployment for this predictor and persists it in the Model Serving. @@ -268,6 +270,9 @@ def extract_fields_from_json(cls, json_decamelized): kwargs["created_at"] = json_decamelized.pop("created") kwargs["creator"] = json_decamelized.pop("creator") kwargs["api_protocol"] = json_decamelized.pop("api_protocol") + if "environmentdto" in json_decamelized: + environment = json_decamelized.pop("environmentdto") + kwargs["environment"] = environment["name"] return kwargs def update_from_response_json(self, json_dict): @@ -296,6 +301,8 @@ def to_dict(self): "predictor": self._script_file, "apiProtocol": self._api_protocol, } + if self.environment is not None: + json = {**json, **{"environmentDTO": {"name": self._environment}}} if self._resources is not None: json = {**json, **self._resources.to_dict()} if self._inference_logger is not None: @@ -457,6 +464,15 @@ def api_protocol(self): def api_protocol(self, api_protocol): self._api_protocol = api_protocol + @property + def environment(self): + """Name of the inference environment""" + return self._environment + + @environment.setter + def environment(self, environment): + self._environment = environment + def __repr__(self): desc = ( f", description: {self._description!r}" diff --git a/python/tests/fixtures/predictor_fixtures.json b/python/tests/fixtures/predictor_fixtures.json index b0b7b2fc..76adeebe 100644 --- a/python/tests/fixtures/predictor_fixtures.json +++ b/python/tests/fixtures/predictor_fixtures.json @@ -40,6 +40,9 @@ "inference_logging": "ALL", "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "misc-inference-pipeline" } } ] @@ -92,6 +95,9 @@ "inference_logging": "ALL", "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "misc-inference-pipeline" } }, { @@ -131,6 +137,9 @@ "inference_logging": "ALL", "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "misc-inference-pipeline" } } ] @@ -160,6 +169,9 @@ "inference_logging": "ALL", "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "tensorflow-inference-pipeline" } } }, @@ -200,6 +212,9 @@ "inference_logging": "ALL", "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "tensorflow-inference-pipeline" } } }, @@ -235,6 +250,9 @@ "inference_logging": "ALL", "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "misc-inference-pipeline" } } }, @@ -277,6 +295,9 @@ }, "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "misc-inference-pipeline" } } }, @@ -312,6 +333,9 @@ }, "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "misc-inference-pipeline" } } }, @@ -354,6 +378,9 @@ }, "kafka_topic_dto": { "name": "topic" + }, + "environment_dto": { + "name": "misc-inference-pipeline" } } }, diff --git a/python/tests/test_model.py b/python/tests/test_model.py index 31757c06..92b6e5a9 100644 --- a/python/tests/test_model.py +++ b/python/tests/test_model.py @@ -202,7 +202,7 @@ def test_deploy(self, mocker, backend_fixtures): inference_logger=inference_logger, inference_batcher=inference_batcher, transformer=transformer, - api_protocol=p_json["api_protocol"], + environment=p_json["environment_dto"]["name"], ) # Assert @@ -218,6 +218,7 @@ def test_deploy(self, mocker, backend_fixtures): inference_batcher=inference_batcher, transformer=transformer, api_protocol=p_json["api_protocol"], + environment=p_json["environment_dto"]["name"], ) mock_predictor.deploy.assert_called_once()