Skip to content

Commit

Permalink
feat(huggingface): Merge model settings
Browse files Browse the repository at this point in the history
If settings for the model are provided in both the
`model-settings.json` *and* environment variables, merge them.

Precedence is given to environment variables.
  • Loading branch information
Jesse Claven committed Aug 17, 2023
1 parent 08e7c86 commit a7835a4
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 9 deletions.
55 changes: 46 additions & 9 deletions runtimes/huggingface/mlserver_huggingface/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import orjson

from typing import Optional, Dict
from typing import Optional, Dict, Union
from pydantic import BaseSettings
from distutils.util import strtobool
from transformers.pipelines import SUPPORTED_TASKS
Expand Down Expand Up @@ -110,14 +110,22 @@ def task_name(self):
return self.task


def parse_parameters_from_env() -> Dict:
def parse_parameters_from_env() -> Dict[str, Union[str, bool, float, int]]:
"""
This method parses the environment variables injected via SCv1.
At least an empty dict is always returned.
"""
# TODO: Once support for SCv1 is deprecated, we should remove this method and rely
# purely on settings coming via the `model-settings.json` file.
parameters = orjson.loads(os.environ.get(PARAMETERS_ENV_NAME, "[]"))

parsed_parameters: Dict[str, Union[str, bool, float, int]] = {}

# Guard: Exit early if there's no parameters
if len(parameters) == 0:
return parsed_parameters

type_dict = {
"INT": int,
"FLOAT": float,
Expand All @@ -126,7 +134,6 @@ def parse_parameters_from_env() -> Dict:
"BOOL": bool,
}

parsed_parameters = {}
for param in parameters:
name = param.get("name")
value = param.get("value")
Expand All @@ -140,17 +147,15 @@ def parse_parameters_from_env() -> Dict:
raise InvalidModelParameter(name, value, type_)
except KeyError:
raise InvalidModelParameterType(type_)

return parsed_parameters


def get_huggingface_settings(model_settings: ModelSettings) -> HuggingFaceSettings:
env_params = parse_parameters_from_env()
if not env_params and (
not model_settings.parameters or not model_settings.parameters.extra
):
raise MissingHuggingFaceSettings()
"""Get the HuggingFace settings provided to the runtime"""

extra = env_params or model_settings.parameters.extra # type: ignore
env_params = parse_parameters_from_env()
extra = merge_huggingface_settings_extra(model_settings, env_params)
hf_settings = HuggingFaceSettings(**extra) # type: ignore

if hf_settings.task not in SUPPORTED_TASKS:
Expand All @@ -161,3 +166,35 @@ def get_huggingface_settings(model_settings: ModelSettings) -> HuggingFaceSettin
raise InvalidOptimumTask(hf_settings.task, SUPPORTED_OPTIMUM_TASKS.keys())

return hf_settings


def merge_huggingface_settings_extra(
model_settings: ModelSettings, env_params: Dict[str, Union[str, bool, float, int]]
) -> Dict[str, Union[str, bool, float, int]]:
"""
This function returns the Extra field of the Settings.
It merges them, iff they're both present, from the
environment AND model settings file. Precedence is
giving to the environment.
"""

# Both `parameters` and `extra` are Optional, so we
# need to get the value, or nothing.
settings_params = (
model_settings.parameters.extra
if model_settings.parameters is not None
else None
)

if settings_params is None and env_params == {}:
# There must be settings provided by at least the environment OR model settings
raise MissingHuggingFaceSettings()

# Set the default value
settings_params = settings_params or {}

# Overwrite any conflicting keys, giving precedence to the environment
settings_params.update(env_params)

return settings_params
168 changes: 168 additions & 0 deletions runtimes/huggingface/tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import pytest

from typing import Dict

from mlserver.settings import ModelSettings, ModelParameters

from mlserver_huggingface.runtime import HuggingFaceRuntime
from mlserver_huggingface.settings import (
HuggingFaceSettings,
PARAMETERS_ENV_NAME,
get_huggingface_settings,
merge_huggingface_settings_extra,
)
from mlserver_huggingface.errors import MissingHuggingFaceSettings


@pytest.fixture()
def model_settings_extra_task():
return ModelSettings(
name="foo",
implementation=HuggingFaceRuntime,
parameters=ModelParameters(
extra={"task": "text-generation", "pretrained_model": "distilgpt2"}
),
)


@pytest.fixture()
def model_settings_extra_none():
return ModelSettings(
name="foo",
implementation=HuggingFaceRuntime,
parameters=ModelParameters(extra=None),
)


@pytest.fixture()
def model_settings_extra_empty():
return ModelSettings(
name="foo",
implementation=HuggingFaceRuntime,
parameters=ModelParameters(extra={}),
)


@pytest.mark.parametrize(
"model_settings,env_params,expected",
[
(
"model_settings_extra_task",
{"task": "question-answering"},
{"task": "question-answering", "pretrained_model": "distilgpt2"},
),
(
"model_settings_extra_task",
{},
{"task": "text-generation", "pretrained_model": "distilgpt2"},
),
(
"model_settings_extra_none",
{"task": "question-answering"},
{"task": "question-answering"},
),
(
"model_settings_extra_empty",
{"task": "question-answering"},
{"task": "question-answering"},
),
],
)
def test_merge_huggingface_settings_extra(
model_settings: str,
env_params: Dict,
expected: Dict,
request: pytest.FixtureRequest,
):
assert expected == merge_huggingface_settings_extra(
request.getfixturevalue(model_settings), env_params
)


def test_merge_huggingface_settings_extra_raises(model_settings_extra_none):
with pytest.raises(MissingHuggingFaceSettings):
merge_huggingface_settings_extra(model_settings_extra_none, {})


@pytest.mark.parametrize(
"model_settings,env_params,expected",
[
(
"model_settings_extra_task",
'[{"name": "task", "value": "question-answering", "type": "STRING"}]',
HuggingFaceSettings(
task="question-answering",
task_suffix="",
pretrained_model="distilgpt2",
pretrained_tokenizer=None,
framework=None,
optimum_model=False,
device=-1,
inter_op_threads=None,
intra_op_threads=None,
),
),
(
"model_settings_extra_task",
"[]",
HuggingFaceSettings(
task="text-generation",
task_suffix="",
pretrained_model="distilgpt2",
pretrained_tokenizer=None,
framework=None,
optimum_model=False,
device=-1,
inter_op_threads=None,
intra_op_threads=None,
),
),
(
"model_settings_extra_none",
'[{"name": "task", "value": "question-answering", "type": "STRING"}]',
HuggingFaceSettings(
task="question-answering",
task_suffix="",
pretrained_model=None,
pretrained_tokenizer=None,
framework=None,
optimum_model=False,
device=-1,
inter_op_threads=None,
intra_op_threads=None,
),
),
(
"model_settings_extra_empty",
'[{"name": "task", "value": "question-answering", "type": "STRING"}]',
HuggingFaceSettings(
task="question-answering",
task_suffix="",
pretrained_model=None,
pretrained_tokenizer=None,
framework=None,
optimum_model=False,
device=-1,
inter_op_threads=None,
intra_op_threads=None,
),
),
],
)
def test_get_huggingface_settings(
model_settings: str,
env_params: str,
expected: HuggingFaceSettings,
request: pytest.FixtureRequest,
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setenv(PARAMETERS_ENV_NAME, env_params)

assert expected == get_huggingface_settings(request.getfixturevalue(model_settings))

monkeypatch.delenv(PARAMETERS_ENV_NAME)


def test_get_huggingface_settings_raises(model_settings_extra_none):
with pytest.raises(MissingHuggingFaceSettings):
get_huggingface_settings(model_settings_extra_none)

0 comments on commit a7835a4

Please sign in to comment.