From 7d74f73f2b623dad12e48884b44fcee41066460a Mon Sep 17 00:00:00 2001
From: Collin Dutter
Date: Thu, 14 Nov 2024 22:39:44 +0000
Subject: [PATCH] Fix hugging face tests for some reason (#1349)

---
 CHANGELOG.md                                   |  1 -
 ...est_hugging_face_pipeline_prompt_driver.py | 44 +++++++++----------
 2 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 298e6dd08..c22d6a7f7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -67,7 +67,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `AmazonBedrockPromptDriver` not working without setting `max_tokens`.
 - `BaseImageGenerationTask` no longer prevents setting `negative_rulesets` _and_ `negative_rules` at the same time.
 
-
 ## \[0.34.3\] - 2024-11-13
 
 ### Fixed
diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
index e3c99f402..af52ca4e9 100644
--- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
@@ -7,14 +7,12 @@ class TestHuggingFacePipelinePromptDriver:
     @pytest.fixture(autouse=True)
     def mock_pipeline(self, mocker):
-        return mocker.patch("transformers.pipeline")
+        mock_pipeline = mocker.patch("transformers.pipeline")
+        mock_pipeline = mock_pipeline.return_value
+        mock_pipeline.task = "text-generation"
+        mock_pipeline.return_value = [{"generated_text": [{"content": "model-output"}]}]
 
-    @pytest.fixture(autouse=True)
-    def mock_provider(self, mock_pipeline):
-        mock_provider = mock_pipeline.return_value
-        mock_provider.task = "text-generation"
-        mock_provider.return_value = [{"generated_text": [{"content": "model-output"}]}]
-        return mock_provider
+        return mock_pipeline
 
     @pytest.fixture(autouse=True)
     def mock_autotokenizer(self, mocker):
@@ -41,27 +39,27 @@ def messages(self):
             {"role": "assistant", "content": "assistant-input"},
         ]
 
-    def test_init(self):
-        assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42)
+    def test_init(self, mock_pipeline):
+        assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42, pipeline=mock_pipeline)
 
     def test_try_run(self, prompt_stack, messages, mock_pipeline):
         # Given
-        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, extra_params={"foo": "bar"})
+        driver = HuggingFacePipelinePromptDriver(
+            model="foo", max_tokens=42, extra_params={"foo": "bar"}, pipeline=mock_pipeline
+        )
 
         # When
         message = driver.try_run(prompt_stack)
 
         # Then
-        mock_pipeline.return_value.assert_called_once_with(
-            messages, max_new_tokens=42, temperature=0.1, do_sample=True, foo="bar"
-        )
+        mock_pipeline.assert_called_once_with(messages, max_new_tokens=42, temperature=0.1, do_sample=True, foo="bar")
         assert message.value == "model-output"
         assert message.usage.input_tokens == 3
         assert message.usage.output_tokens == 3
 
-    def test_try_stream(self, prompt_stack):
+    def test_try_stream(self, prompt_stack, mock_pipeline):
         # Given
-        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
+        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)
 
         # When
         with pytest.raises(Exception) as e:
@@ -70,10 +68,10 @@ def test_try_stream(self, prompt_stack):
         assert e.value.args[0] == "streaming is not supported"
 
     @pytest.mark.parametrize("choices", [[], [1, 2]])
-    def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_provider, prompt_stack):
+    def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_pipeline, prompt_stack):
         # Given
-        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
-        mock_provider.return_value = choices
+        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)
+        mock_pipeline.return_value = choices
 
         # When
         with pytest.raises(Exception) as e:
@@ -82,10 +80,10 @@ def test_try_run_throws_when_multiple_choi
         # Then
         assert e.value.args[0] == "completion with more than one choice is not supported yet"
 
-    def test_try_run_throws_when_non_list(self, mock_provider, prompt_stack):
+    def test_try_run_throws_when_non_list(self, mock_pipeline, prompt_stack):
         # Given
-        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
-        mock_provider.return_value = {}
+        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)
+        mock_pipeline.return_value = {}
 
         # When
         with pytest.raises(Exception) as e:
@@ -94,9 +92,9 @@ def test_try_run_throws_when_non_list(self, mock_provider, prompt_stack):
         # Then
         assert e.value.args[0] == "invalid output format"
 
-    def test_prompt_stack_to_string(self, prompt_stack):
+    def test_prompt_stack_to_string(self, prompt_stack, mock_pipeline):
         # Given
-        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
+        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)
 
         # When
         result = driver.prompt_stack_to_string(prompt_stack)