From a7835a4c823a76bf8b5e4cd4d7f5b816edd596b9 Mon Sep 17 00:00:00 2001
From: Jesse Claven
Date: Wed, 16 Aug 2023 10:09:42 +0100
Subject: [PATCH] feat(huggingface): Merge model settings

If settings for the model are provided in both the `model-settings.json`
*and* environment variables, merge them. Precedence is given to
environment variables.
---
 .../mlserver_huggingface/settings.py        |  55 +++++-
 runtimes/huggingface/tests/test_settings.py | 168 ++++++++++++++++++
 2 files changed, 214 insertions(+), 9 deletions(-)
 create mode 100644 runtimes/huggingface/tests/test_settings.py

diff --git a/runtimes/huggingface/mlserver_huggingface/settings.py b/runtimes/huggingface/mlserver_huggingface/settings.py
index 0c3087b5a..b70c1c453 100644
--- a/runtimes/huggingface/mlserver_huggingface/settings.py
+++ b/runtimes/huggingface/mlserver_huggingface/settings.py
@@ -1,7 +1,7 @@
 import os
 import orjson
 
-from typing import Optional, Dict
+from typing import Optional, Dict, Union
 from pydantic import BaseSettings
 from distutils.util import strtobool
 from transformers.pipelines import SUPPORTED_TASKS
@@ -110,14 +110,22 @@ def task_name(self):
         return self.task
 
 
-def parse_parameters_from_env() -> Dict:
+def parse_parameters_from_env() -> Dict[str, Union[str, bool, float, int]]:
     """
     This method parses the environment variables injected via SCv1.
+
+    At least an empty dict is always returned.
     """
     # TODO: Once support for SCv1 is deprecated, we should remove this method and rely
     # purely on settings coming via the `model-settings.json` file.
     parameters = orjson.loads(os.environ.get(PARAMETERS_ENV_NAME, "[]"))
 
+    parsed_parameters: Dict[str, Union[str, bool, float, int]] = {}
+
+    # Guard: Exit early if there are no parameters
+    if len(parameters) == 0:
+        return parsed_parameters
+
     type_dict = {
         "INT": int,
         "FLOAT": float,
@@ -126,7 +134,6 @@ def parse_parameters_from_env() -> Dict:
         "BOOL": bool,
     }
 
-    parsed_parameters = {}
     for param in parameters:
         name = param.get("name")
         value = param.get("value")
@@ -140,17 +147,15 @@ def parse_parameters_from_env() -> Dict:
                 raise InvalidModelParameter(name, value, type_)
             except KeyError:
                 raise InvalidModelParameterType(type_)
+
     return parsed_parameters
 
 
 def get_huggingface_settings(model_settings: ModelSettings) -> HuggingFaceSettings:
-    env_params = parse_parameters_from_env()
-    if not env_params and (
-        not model_settings.parameters or not model_settings.parameters.extra
-    ):
-        raise MissingHuggingFaceSettings()
+    """Get the HuggingFace settings provided to the runtime"""
 
-    extra = env_params or model_settings.parameters.extra  # type: ignore
+    env_params = parse_parameters_from_env()
+    extra = merge_huggingface_settings_extra(model_settings, env_params)
     hf_settings = HuggingFaceSettings(**extra)  # type: ignore
 
     if hf_settings.task not in SUPPORTED_TASKS:
@@ -161,3 +166,35 @@ def get_huggingface_settings(model_settings: ModelSettings) -> HuggingFaceSettin
             raise InvalidOptimumTask(hf_settings.task, SUPPORTED_OPTIMUM_TASKS.keys())
 
     return hf_settings
+
+
+def merge_huggingface_settings_extra(
+    model_settings: ModelSettings, env_params: Dict[str, Union[str, bool, float, int]]
+) -> Dict[str, Union[str, bool, float, int]]:
+    """
+    This function returns the Extra field of the Settings.
+
+    If both are present, it merges the extra settings from the
+    environment AND the model settings file. Precedence is
+    given to the environment.
+    """
+
+    # Both `parameters` and `extra` are Optional, so we
+    # need to get the value, or nothing.
+    settings_params = (
+        model_settings.parameters.extra
+        if model_settings.parameters is not None
+        else None
+    )
+
+    if settings_params is None and env_params == {}:
+        # There must be settings provided by at least the environment OR model settings
+        raise MissingHuggingFaceSettings()
+
+    # Set the default value
+    settings_params = settings_params or {}
+
+    # Overwrite any conflicting keys, giving precedence to the environment
+    settings_params.update(env_params)
+
+    return settings_params
diff --git a/runtimes/huggingface/tests/test_settings.py b/runtimes/huggingface/tests/test_settings.py
new file mode 100644
index 000000000..13fd95205
--- /dev/null
+++ b/runtimes/huggingface/tests/test_settings.py
@@ -0,0 +1,168 @@
+import pytest
+
+from typing import Dict
+
+from mlserver.settings import ModelSettings, ModelParameters
+
+from mlserver_huggingface.runtime import HuggingFaceRuntime
+from mlserver_huggingface.settings import (
+    HuggingFaceSettings,
+    PARAMETERS_ENV_NAME,
+    get_huggingface_settings,
+    merge_huggingface_settings_extra,
+)
+from mlserver_huggingface.errors import MissingHuggingFaceSettings
+
+
+@pytest.fixture()
+def model_settings_extra_task():
+    return ModelSettings(
+        name="foo",
+        implementation=HuggingFaceRuntime,
+        parameters=ModelParameters(
+            extra={"task": "text-generation", "pretrained_model": "distilgpt2"}
+        ),
+    )
+
+
+@pytest.fixture()
+def model_settings_extra_none():
+    return ModelSettings(
+        name="foo",
+        implementation=HuggingFaceRuntime,
+        parameters=ModelParameters(extra=None),
+    )
+
+
+@pytest.fixture()
+def model_settings_extra_empty():
+    return ModelSettings(
+        name="foo",
+        implementation=HuggingFaceRuntime,
+        parameters=ModelParameters(extra={}),
+    )
+
+
+@pytest.mark.parametrize(
+    "model_settings,env_params,expected",
+    [
+        (
+            "model_settings_extra_task",
+            {"task": "question-answering"},
+            {"task": "question-answering", "pretrained_model": "distilgpt2"},
+        ),
+        (
+            "model_settings_extra_task",
+            {},
+            {"task": "text-generation", "pretrained_model": "distilgpt2"},
+        ),
+        (
+            "model_settings_extra_none",
+            {"task": "question-answering"},
+            {"task": "question-answering"},
+        ),
+        (
+            "model_settings_extra_empty",
+            {"task": "question-answering"},
+            {"task": "question-answering"},
+        ),
+    ],
+)
+def test_merge_huggingface_settings_extra(
+    model_settings: str,
+    env_params: Dict,
+    expected: Dict,
+    request: pytest.FixtureRequest,
+):
+    assert expected == merge_huggingface_settings_extra(
+        request.getfixturevalue(model_settings), env_params
+    )
+
+
+def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
+    with pytest.raises(MissingHuggingFaceSettings):
+        merge_huggingface_settings_extra(model_settings_extra_none, {})
+
+
+@pytest.mark.parametrize(
+    "model_settings,env_params,expected",
+    [
+        (
+            "model_settings_extra_task",
+            '[{"name": "task", "value": "question-answering", "type": "STRING"}]',
+            HuggingFaceSettings(
+                task="question-answering",
+                task_suffix="",
+                pretrained_model="distilgpt2",
+                pretrained_tokenizer=None,
+                framework=None,
+                optimum_model=False,
+                device=-1,
+                inter_op_threads=None,
+                intra_op_threads=None,
+            ),
+        ),
+        (
+            "model_settings_extra_task",
+            "[]",
+            HuggingFaceSettings(
+                task="text-generation",
+                task_suffix="",
+                pretrained_model="distilgpt2",
+                pretrained_tokenizer=None,
+                framework=None,
+                optimum_model=False,
+                device=-1,
+                inter_op_threads=None,
+                intra_op_threads=None,
+            ),
+        ),
+        (
+            "model_settings_extra_none",
+            '[{"name": "task", "value": "question-answering", "type": "STRING"}]',
+            HuggingFaceSettings(
+                task="question-answering",
+                task_suffix="",
+                pretrained_model=None,
+                pretrained_tokenizer=None,
+                framework=None,
+                optimum_model=False,
+                device=-1,
+                inter_op_threads=None,
+                intra_op_threads=None,
+            ),
+        ),
+        (
+            "model_settings_extra_empty",
+            '[{"name": "task", "value": "question-answering", "type": "STRING"}]',
+            HuggingFaceSettings(
+                task="question-answering",
+                task_suffix="",
+                pretrained_model=None,
+                pretrained_tokenizer=None,
+                framework=None,
+                optimum_model=False,
+                device=-1,
+                inter_op_threads=None,
+                intra_op_threads=None,
+            ),
+        ),
+    ],
+)
+def test_get_huggingface_settings(
+    model_settings: str,
+    env_params: str,
+    expected: HuggingFaceSettings,
+    request: pytest.FixtureRequest,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    monkeypatch.setenv(PARAMETERS_ENV_NAME, env_params)
+
+    assert expected == get_huggingface_settings(request.getfixturevalue(model_settings))
+
+    monkeypatch.delenv(PARAMETERS_ENV_NAME)
+
+
+def test_get_huggingface_settings_raises(model_settings_extra_none):
+    with pytest.raises(MissingHuggingFaceSettings):
+        get_huggingface_settings(model_settings_extra_none)
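
Reviewer note (not part of the patch): a minimal sketch of the merge behaviour this change introduces, reusing the same values as the test fixtures above. It assumes the runtime and its dependencies are importable, and it sets the SCv1 environment variable named by PARAMETERS_ENV_NAME directly instead of via monkeypatch; the conflicting key (`task`) is taken from the environment, while the key only set in `model-settings.json` (`pretrained_model`) is kept.

    import json
    import os

    from mlserver.settings import ModelParameters, ModelSettings
    from mlserver_huggingface.runtime import HuggingFaceRuntime
    from mlserver_huggingface.settings import (
        PARAMETERS_ENV_NAME,
        get_huggingface_settings,
    )

    # Extra settings as they would come from `model-settings.json`
    # (the model name "example" is illustrative).
    model_settings = ModelSettings(
        name="example",
        implementation=HuggingFaceRuntime,
        parameters=ModelParameters(
            extra={"task": "text-generation", "pretrained_model": "distilgpt2"}
        ),
    )

    # SCv1-style parameters injected via the environment, in the same JSON
    # shape used by test_get_huggingface_settings.
    os.environ[PARAMETERS_ENV_NAME] = json.dumps(
        [{"name": "task", "value": "question-answering", "type": "STRING"}]
    )

    hf_settings = get_huggingface_settings(model_settings)
    assert hf_settings.task == "question-answering"  # environment wins on conflict
    assert hf_settings.pretrained_model == "distilgpt2"  # file-only key is kept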