Skip to content

Commit

Permalink
refactor(huggingface): Add custom type for Extras
Browse files Browse the repository at this point in the history
The possible types are constrained in what's allowed. Specify this new
type that follows those constraints.
  • Loading branch information
Jesse Claven committed Aug 17, 2023
1 parent a7835a4 commit 31aeb44
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
33 changes: 18 additions & 15 deletions runtimes/huggingface/mlserver_huggingface/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import orjson

from typing import Optional, Dict, Union
from typing import Optional, Dict, Union, NewType
from pydantic import BaseSettings
from distutils.util import strtobool
from transformers.pipelines import SUPPORTED_TASKS
Expand Down Expand Up @@ -110,7 +110,18 @@ def task_name(self):
return self.task


def parse_parameters_from_env() -> Dict[str, Union[str, bool, float, int]]:
EXTRA_TYPE_DICT = {
"INT": int,
"FLOAT": float,
"DOUBLE": float,
"STRING": str,
"BOOL": bool,
}

ExtraDict = NewType("ExtraDict", Dict[str, Union[str, bool, float, int]])


def parse_parameters_from_env() -> ExtraDict:
"""
This method parses the environment variables injected via SCv1.
Expand All @@ -120,20 +131,12 @@ def parse_parameters_from_env() -> Dict[str, Union[str, bool, float, int]]:
# purely on settings coming via the `model-settings.json` file.
parameters = orjson.loads(os.environ.get(PARAMETERS_ENV_NAME, "[]"))

parsed_parameters: Dict[str, Union[str, bool, float, int]] = {}
parsed_parameters: ExtraDict = ExtraDict({})

# Guard: Exit early if there's no parameters
if len(parameters) == 0:
return parsed_parameters

type_dict = {
"INT": int,
"FLOAT": float,
"DOUBLE": float,
"STRING": str,
"BOOL": bool,
}

for param in parameters:
name = param.get("name")
value = param.get("value")
Expand All @@ -142,7 +145,7 @@ def parse_parameters_from_env() -> Dict[str, Union[str, bool, float, int]]:
parsed_parameters[name] = bool(strtobool(value))
else:
try:
parsed_parameters[name] = type_dict[type_](value)
parsed_parameters[name] = EXTRA_TYPE_DICT[type_](value)
except ValueError:
raise InvalidModelParameter(name, value, type_)
except KeyError:
Expand All @@ -169,8 +172,8 @@ def get_huggingface_settings(model_settings: ModelSettings) -> HuggingFaceSettin


def merge_huggingface_settings_extra(
model_settings: ModelSettings, env_params: Dict[str, Union[str, bool, float, int]]
) -> Dict[str, Union[str, bool, float, int]]:
model_settings: ModelSettings, env_params: ExtraDict
) -> ExtraDict:
"""
This function returns the Extra field of the Settings.
Expand All @@ -197,4 +200,4 @@ def merge_huggingface_settings_extra(
# Overwrite any conflicting keys, giving precedence to the environment
settings_params.update(env_params)

return settings_params
return ExtraDict(settings_params)
3 changes: 2 additions & 1 deletion runtimes/huggingface/tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mlserver_huggingface.runtime import HuggingFaceRuntime
from mlserver_huggingface.settings import (
HuggingFaceSettings,
ExtraDict,
PARAMETERS_ENV_NAME,
get_huggingface_settings,
merge_huggingface_settings_extra,
Expand Down Expand Up @@ -70,7 +71,7 @@ def model_settings_extra_empty():
)
def test_merge_huggingface_settings_extra(
model_settings: str,
env_params: Dict,
env_params: ExtraDict,
expected: Dict,
request: pytest.FixtureRequest,
):
Expand Down

0 comments on commit 31aeb44

Please sign in to comment.