From 31aeb44889580a83b32e371593be73a15e14ab3d Mon Sep 17 00:00:00 2001 From: Jesse Claven Date: Thu, 17 Aug 2023 10:56:31 +0100 Subject: [PATCH] refactor(huggingface): Add custom type for Extras The possible types are constrained in what's allowed. Specify this new type that follows those constraints. --- .../mlserver_huggingface/settings.py | 33 ++++++++++--------- runtimes/huggingface/tests/test_settings.py | 3 +- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/runtimes/huggingface/mlserver_huggingface/settings.py b/runtimes/huggingface/mlserver_huggingface/settings.py index b70c1c453..668d21410 100644 --- a/runtimes/huggingface/mlserver_huggingface/settings.py +++ b/runtimes/huggingface/mlserver_huggingface/settings.py @@ -1,7 +1,7 @@ import os import orjson -from typing import Optional, Dict, Union +from typing import Optional, Dict, Union, NewType from pydantic import BaseSettings from distutils.util import strtobool from transformers.pipelines import SUPPORTED_TASKS @@ -110,7 +110,18 @@ def task_name(self): return self.task -def parse_parameters_from_env() -> Dict[str, Union[str, bool, float, int]]: +EXTRA_TYPE_DICT = { + "INT": int, + "FLOAT": float, + "DOUBLE": float, + "STRING": str, + "BOOL": bool, +} + +ExtraDict = NewType("ExtraDict", Dict[str, Union[str, bool, float, int]]) + + +def parse_parameters_from_env() -> ExtraDict: """ This method parses the environment variables injected via SCv1. @@ -120,20 +131,12 @@ def parse_parameters_from_env() -> Dict[str, Union[str, bool, float, int]]: # purely on settings coming via the `model-settings.json` file. parameters = orjson.loads(os.environ.get(PARAMETERS_ENV_NAME, "[]")) - parsed_parameters: Dict[str, Union[str, bool, float, int]] = {} + parsed_parameters: ExtraDict = ExtraDict({}) # Guard: Exit early if there's no parameters if len(parameters) == 0: return parsed_parameters - type_dict = { - "INT": int, - "FLOAT": float, - "DOUBLE": float, - "STRING": str, - "BOOL": bool, - } - for param in parameters: name = param.get("name") value = param.get("value") @@ -142,7 +145,7 @@ def parse_parameters_from_env() -> Dict[str, Union[str, bool, float, int]]: parsed_parameters[name] = bool(strtobool(value)) else: try: - parsed_parameters[name] = type_dict[type_](value) + parsed_parameters[name] = EXTRA_TYPE_DICT[type_](value) except ValueError: raise InvalidModelParameter(name, value, type_) except KeyError: @@ -169,8 +172,8 @@ def get_huggingface_settings(model_settings: ModelSettings) -> HuggingFaceSettin def merge_huggingface_settings_extra( - model_settings: ModelSettings, env_params: Dict[str, Union[str, bool, float, int]] -) -> Dict[str, Union[str, bool, float, int]]: + model_settings: ModelSettings, env_params: ExtraDict +) -> ExtraDict: """ This function returns the Extra field of the Settings. @@ -197,4 +200,4 @@ def merge_huggingface_settings_extra( # Overwrite any conflicting keys, giving precedence to the environment settings_params.update(env_params) - return settings_params + return ExtraDict(settings_params) diff --git a/runtimes/huggingface/tests/test_settings.py b/runtimes/huggingface/tests/test_settings.py index 13fd95205..01a4c71a9 100644 --- a/runtimes/huggingface/tests/test_settings.py +++ b/runtimes/huggingface/tests/test_settings.py @@ -7,6 +7,7 @@ from mlserver_huggingface.runtime import HuggingFaceRuntime from mlserver_huggingface.settings import ( HuggingFaceSettings, + ExtraDict, PARAMETERS_ENV_NAME, get_huggingface_settings, merge_huggingface_settings_extra, @@ -70,7 +71,7 @@ def model_settings_extra_empty(): ) def test_merge_huggingface_settings_extra( model_settings: str, - env_params: Dict, + env_params: ExtraDict, expected: Dict, request: pytest.FixtureRequest, ):