Skip to content

Commit

Permalink
Refactored tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vtaskow committed Aug 16, 2023
1 parent 77a7eb7 commit dce1316
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 57 deletions.
3 changes: 1 addition & 2 deletions runtimes/huggingface/mlserver_huggingface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def load_pipeline_from_settings(

model = hf_settings.pretrained_model
if not model:
if settings.parameters is not None:
model = settings.parameters.uri
model = settings.parameters.uri
tokenizer = hf_settings.pretrained_tokenizer
if not tokenizer:
tokenizer = hf_settings.pretrained_model
Expand Down
86 changes: 31 additions & 55 deletions runtimes/huggingface/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,69 +50,45 @@ def test_load_pipeline(optimum_model: bool, expected):


@pytest.mark.parametrize(
"has_model_params, param_uri_value, expected",
"pretrained_model, parameters_uri, expected",
[
(True, "", "./some-pretrained-folder"),
(True, None, "./some-pretrained-folder"),
(True, "/some/folder/model-artefacts", "./some-pretrained-folder"),
(False, "", "./some-pretrained-folder"),
],
)
@patch("mlserver_huggingface.common._get_pipeline_class")
def test_pipeline_was_initialised_and_pretrained_model_takes_precedence(
mock_pipeline_factory,
has_model_params: bool,
param_uri_value: Optional[str],
expected: str,
):
mock_pipeline_factory.return_value = MagicMock()

hf_settings = HuggingFaceSettings(pretrained_model="./some-pretrained-folder")
model_params = None
if has_model_params:
model_params = ModelParameters(uri=param_uri_value)

model_settings = ModelSettings(
name="foo",
implementation=HuggingFaceRuntime,
parameters=model_params,
)

_ = load_pipeline_from_settings(hf_settings, model_settings)

mock_pipeline_factory.return_value.assert_called_once()
pipeline_call_args = mock_pipeline_factory.return_value.call_args

assert pipeline_call_args.kwargs["model"] == expected


@pytest.mark.parametrize(
"empty_pretrained_model_value, has_model_params, param_uri_value, expected",
[
(None, True, None, None),
(None, True, "", ""),
(None, True, "/some/folder/model-artefacts", "/some/folder/model-artefacts"),
(None, False, "", None),
("", True, None, None),
("", True, "", ""),
("", True, "/some/folder/model-artefacts", "/some/folder/model-artefacts"),
("", False, "", ""),
(None, None, None),
(None, "", ""),
(None, "/some/folder/model-artefacts", "/some/folder/model-artefacts"),
("", None, None),
("", "", ""),
("", "/some/folder/model-artefacts", "/some/folder/model-artefacts"),
("some-model", None, "some-model"),
("some-model", "", "some-model"),
("some-model", "/some/folder/model-artefacts", "some-model"),
(
"/some/other/folder/model-artefacts",
None,
"/some/other/folder/model-artefacts",
),
(
"/some/other/folder/model-artefacts",
"",
"/some/other/folder/model-artefacts",
),
(
"/some/other/folder/model-artefacts",
"/some/folder/model-artefacts",
"/some/other/folder/model-artefacts",
),
],
)
@patch("mlserver_huggingface.common._get_pipeline_class")
def test_pipeline_was_initialised_when_pretrained_model_is_not_supplied(
def test_pipeline_was_initialised_with_correct_model_param(
mock_pipeline_factory,
empty_pretrained_model_value: Optional[str],
has_model_params: bool,
param_uri_value: Optional[str],
expected: str,
pretrained_model: Optional[str],
parameters_uri: Optional[str],
expected: Optional[str],
):
mock_pipeline_factory.return_value = MagicMock()

hf_settings = HuggingFaceSettings(pretrained_model=empty_pretrained_model_value)
model_params = None
if has_model_params:
model_params = ModelParameters(uri=param_uri_value)
hf_settings = HuggingFaceSettings(pretrained_model=pretrained_model)
model_params = ModelParameters(uri=parameters_uri)

model_settings = ModelSettings(
name="foo",
Expand Down

0 comments on commit dce1316

Please sign in to comment.