Fix hugging face tests for some reason (#1349)
collindutter authored Nov 14, 2024
1 parent dd84468 commit 7d74f73
Showing 2 changed files with 21 additions and 24 deletions.
CHANGELOG.md (1 change: 0 additions & 1 deletion)
@@ -67,7 +67,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `AmazonBedrockPromptDriver` not working without setting `max_tokens`.
 - `BaseImageGenerationTask` no longer prevents setting `negative_rulesets` _and_ `negative_rules` at the same time.
 
-
 ## [0.34.3] - 2024-11-13
 
 ### Fixed

tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py (21 additions & 23 deletions)

@@ -7,14 +7,12 @@
 class TestHuggingFacePipelinePromptDriver:
     @pytest.fixture(autouse=True)
     def mock_pipeline(self, mocker):
-        return mocker.patch("transformers.pipeline")
+        mock_pipeline = mocker.patch("transformers.pipeline")
+        mock_pipeline = mock_pipeline.return_value
+        mock_pipeline.task = "text-generation"
+        mock_pipeline.return_value = [{"generated_text": [{"content": "model-output"}]}]
 
-    @pytest.fixture(autouse=True)
-    def mock_provider(self, mock_pipeline):
-        mock_provider = mock_pipeline.return_value
-        mock_provider.task = "text-generation"
-        mock_provider.return_value = [{"generated_text": [{"content": "model-output"}]}]
-        return mock_provider
+        return mock_pipeline
 
     @pytest.fixture(autouse=True)
     def mock_autotokenizer(self, mocker):
@@ -41,27 +39,27 @@ def messages(self):
{"role": "assistant", "content": "assistant-input"},
]

def test_init(self):
assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42)
def test_init(self, mock_pipeline):
assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42, pipeline=mock_pipeline)

def test_try_run(self, prompt_stack, messages, mock_pipeline):
# Given
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, extra_params={"foo": "bar"})
driver = HuggingFacePipelinePromptDriver(
model="foo", max_tokens=42, extra_params={"foo": "bar"}, pipeline=mock_pipeline
)

# When
message = driver.try_run(prompt_stack)

# Then
mock_pipeline.return_value.assert_called_once_with(
messages, max_new_tokens=42, temperature=0.1, do_sample=True, foo="bar"
)
mock_pipeline.assert_called_once_with(messages, max_new_tokens=42, temperature=0.1, do_sample=True, foo="bar")
assert message.value == "model-output"
assert message.usage.input_tokens == 3
assert message.usage.output_tokens == 3

def test_try_stream(self, prompt_stack):
def test_try_stream(self, prompt_stack, mock_pipeline):
# Given
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)

# When
with pytest.raises(Exception) as e:
@@ -70,10 +68,10 @@ def test_try_stream(self, prompt_stack):
assert e.value.args[0] == "streaming is not supported"

@pytest.mark.parametrize("choices", [[], [1, 2]])
def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_provider, prompt_stack):
def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_pipeline, prompt_stack):
# Given
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
mock_provider.return_value = choices
driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)
mock_pipeline.return_value = choices

# When
with pytest.raises(Exception) as e:
@@ -82,10 +80,10 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_provi
         # Then
         assert e.value.args[0] == "completion with more than one choice is not supported yet"
 
-    def test_try_run_throws_when_non_list(self, mock_provider, prompt_stack):
+    def test_try_run_throws_when_non_list(self, mock_pipeline, prompt_stack):
         # Given
-        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
-        mock_provider.return_value = {}
+        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)
+        mock_pipeline.return_value = {}
 
         # When
         with pytest.raises(Exception) as e:
@@ -94,9 +92,9 @@ def test_try_run_throws_when_non_list(self, mock_provider, prompt_stack):
         # Then
         assert e.value.args[0] == "invalid output format"
 
-    def test_prompt_stack_to_string(self, prompt_stack):
+    def test_prompt_stack_to_string(self, prompt_stack, mock_pipeline):
         # Given
-        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42)
+        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)
 
         # When
         result = driver.prompt_stack_to_string(prompt_stack)
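Outside the test suite, the same injection pattern suggests a caller can hand the driver a pre-built pipeline rather than letting it construct one. A hedged sketch, assuming the `pipeline` parameter exercised above accepts any `transformers` pipeline instance and that the driver is exported from `griptape.drivers` (both assumptions, not confirmed by this diff):

from transformers import pipeline  # real transformers factory

from griptape.drivers import HuggingFacePipelinePromptDriver  # assumed import path

# Build the pipeline up front and inject it, mirroring the tests' pipeline=mock_pipeline.
text_generation = pipeline("text-generation", model="gpt2")
driver = HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42, pipeline=text_generation)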
