diff --git a/tests/unit/drivers/prompt/test_lm_studio_prompt_driver.py b/tests/unit/drivers/prompt/test_lm_studio_prompt_driver.py index 28847e746e..4aa73ce40c 100644 --- a/tests/unit/drivers/prompt/test_lm_studio_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_lm_studio_prompt_driver.py @@ -14,7 +14,7 @@ def mock_chat_completion_create(self, mocker): mock_function.name = "MockTool_test" mock_chat_create.return_value = Mock( headers={}, - choices=[Mock(message=Mock(content="model-output"))], + choices=[Mock(message=Mock(content="model-output", tool_calls=None))], usage=Mock(prompt_tokens=5, completion_tokens=10), ) @@ -36,11 +36,8 @@ def mock_chat_completion_stream_create(self, mocker): ) return mock_chat_create - def test_init(self): - assert LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF") - - def test_try_run(self, mock_client): - # Given + @pytest.fixture + def prompt_stack(self): prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -50,55 +47,57 @@ def test_try_run(self, mock_client): ) ) prompt_stack.add_assistant_message("assistant-input") - driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF") - expected_messages = [ + + return prompt_stack + + @pytest.fixture + def messages(self): + return [ {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, - {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, + { + "role": "user", + "content": [ + {"type": "text", "text": "user-input"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,aW1hZ2UtZGF0YQ=="}}, + ], + }, {"role": "assistant", "content": "assistant-input"}, ] + def test_init(self): + assert LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF") + + def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): + # Given + driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF") # When message = driver.try_run(prompt_stack) # Then - mock_client.return_value.chat.assert_called_once_with( - messages=expected_messages, - model=driver.model, - options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + mock_chat_completion_create.assert_called_once_with( + messages=messages, model=driver.model, temperature=driver.temperature, user=driver.user, seed=driver.seed ) assert message.value == "model-output" - assert message.usage.input_tokens is None - assert message.usage.output_tokens is None + assert message.usage.input_tokens == 5 + assert message.usage.output_tokens == 10 - def test_try_stream_run(self, mock_stream_client): + def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message( - ListArtifact( - [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] - ) - ) - prompt_stack.add_assistant_message("assistant-input") - expected_messages = [ - {"role": "system", "content": "system-input"}, - {"role": "user", "content": "user-input"}, - {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, - {"role": "assistant", "content": "assistant-input"}, - ] driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True) # When text_artifact = next(driver.try_stream(prompt_stack)) # Then - mock_stream_client.return_value.chat.assert_called_once_with( - messages=expected_messages, + mock_chat_completion_stream_create.assert_called_once_with( + messages=messages, model=driver.model, - options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + temperature=driver.temperature, + user=driver.user, + seed=driver.seed, stream=True, + stream_options={"include_usage": True}, ) if isinstance(text_artifact, TextDeltaMessageContent): assert text_artifact.text == "model-output"