Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jul 10, 2024
1 parent fd1c558 commit f87bd01
Showing 1 changed file with 33 additions and 34 deletions.
67 changes: 33 additions & 34 deletions tests/unit/drivers/prompt/test_lm_studio_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def mock_chat_completion_create(self, mocker):
mock_function.name = "MockTool_test"
mock_chat_create.return_value = Mock(
headers={},
choices=[Mock(message=Mock(content="model-output"))],
choices=[Mock(message=Mock(content="model-output", tool_calls=None))],
usage=Mock(prompt_tokens=5, completion_tokens=10),
)

Expand All @@ -36,11 +36,8 @@ def mock_chat_completion_stream_create(self, mocker):
)
return mock_chat_create

def test_init(self):
assert LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF")

def test_try_run(self, mock_client):
# Given
@pytest.fixture
def prompt_stack(self):
prompt_stack = PromptStack()
prompt_stack.add_system_message("system-input")
prompt_stack.add_user_message("user-input")
Expand All @@ -50,55 +47,57 @@ def test_try_run(self, mock_client):
)
)
prompt_stack.add_assistant_message("assistant-input")
driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF")
expected_messages = [

return prompt_stack

@pytest.fixture
def messages(self):
    """Expected `messages` payload for the stack built by the `prompt_stack` fixture.

    The image is encoded as an OpenAI-style `image_url` content part carrying a
    base64 data URL (`aW1hZ2UtZGF0YQ==` is b"image-data" base64-encoded), matching
    what the driver is asserted to send in the tests below.
    """
    return [
        {"role": "system", "content": "system-input"},
        {"role": "user", "content": "user-input"},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "user-input"},
                # data URL form: data:image/<format>;base64,<payload>
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,aW1hZ2UtZGF0YQ=="}},
            ],
        },
        {"role": "assistant", "content": "assistant-input"},
    ]

def test_init(self):
    """The driver can be constructed with only a model identifier."""
    driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF")
    assert driver

def test_try_run(self, mock_chat_completion_create, prompt_stack, messages):
    """try_run forwards the stack to the chat-completion API and wraps the response."""
    # Given
    driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF")

    # When
    message = driver.try_run(prompt_stack)

    # Then
    # OpenAI-compatible call shape: parameters are top-level kwargs,
    # not nested in an Ollama-style `options` dict.
    mock_chat_completion_create.assert_called_once_with(
        messages=messages, model=driver.model, temperature=driver.temperature, user=driver.user, seed=driver.seed
    )
    assert message.value == "model-output"
    # Token counts come from the mocked usage (prompt_tokens=5, completion_tokens=10).
    assert message.usage.input_tokens == 5
    assert message.usage.output_tokens == 10

def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages):
    """try_stream requests a streaming completion and yields text deltas."""
    # Given
    driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True)

    # When
    text_artifact = next(driver.try_stream(prompt_stack))

    # Then
    # Streaming adds `stream=True` plus `stream_options` so usage is reported
    # in the final chunk (include_usage).
    mock_chat_completion_stream_create.assert_called_once_with(
        messages=messages,
        model=driver.model,
        temperature=driver.temperature,
        user=driver.user,
        seed=driver.seed,
        stream=True,
        stream_options={"include_usage": True},
    )
    # The first yielded event may be a usage/metadata delta depending on the
    # driver's stream handling, so only text deltas are asserted on.
    if isinstance(text_artifact, TextDeltaMessageContent):
        assert text_artifact.text == "model-output"

0 comments on commit f87bd01

Please sign in to comment.