-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LmStudioPromptDriver and LmStudioEmbeddingDriver
- Loading branch information
1 parent
29481f7
commit f7ff4d9
Showing
9 changed files
with
216 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from __future__ import annotations | ||
from attrs import define, field, Factory | ||
import openai | ||
|
||
from griptape.drivers import OpenAiEmbeddingDriver | ||
|
||
|
||
@define
class LmStudioEmbeddingDriver(OpenAiEmbeddingDriver):
    """Embedding driver for models served locally by LM Studio.

    LM Studio exposes an OpenAI-compatible REST API, so this driver only
    overrides the endpoint and client factory of `OpenAiEmbeddingDriver`.

    Attributes:
        base_url: API URL. Defaults to LM Studio's local v1 API URL.
        client: Optionally provide a custom `openai.OpenAI` client.
    """

    base_url: str = field(default="http://localhost:1234/v1", kw_only=True, metadata={"serializable": True})
    # Forward the parent driver's api_key to the client: the openai SDK raises
    # at construction time when no key is resolvable (even though LM Studio
    # itself ignores it). NOTE(review): assumes the parent OpenAi driver
    # declares an `api_key` field — confirm against OpenAiEmbeddingDriver.
    client: openai.OpenAI = field(
        default=Factory(lambda self: openai.OpenAI(base_url=self.base_url, api_key=self.api_key), takes_self=True)
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from __future__ import annotations | ||
from attrs import define, field, Factory | ||
from griptape.drivers.prompt.openai_chat_prompt_driver import OpenAiChatPromptDriver | ||
|
||
import openai | ||
|
||
|
||
@define
class LmStudioPromptDriver(OpenAiChatPromptDriver):
    """Prompt driver for chat models served locally by LM Studio.

    LM Studio exposes an OpenAI-compatible chat-completions API, so this
    driver only overrides the endpoint and client factory of
    `OpenAiChatPromptDriver`.

    Attributes:
        base_url: API URL. Defaults to LM Studio's local v1 API URL.
        client: Optionally provide a custom `openai.OpenAI` client.
    """

    base_url: str = field(default="http://localhost:1234/v1", kw_only=True, metadata={"serializable": True})
    # Forward the parent driver's api_key to the client: the openai SDK raises
    # at construction time when no key is resolvable (even though LM Studio
    # itself ignores it). NOTE(review): assumes the parent OpenAi driver
    # declares an `api_key` field — confirm against OpenAiChatPromptDriver.
    client: openai.OpenAI = field(
        default=Factory(lambda self: openai.OpenAI(base_url=self.base_url, api_key=self.api_key), takes_self=True)
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
tests/unit/drivers/embedding/test_lm_studio_embedding_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from unittest.mock import Mock | ||
import pytest | ||
from griptape.drivers import LmStudioEmbeddingDriver | ||
from griptape.tokenizers import OpenAiTokenizer | ||
|
||
|
||
class TestLmStudioEmbeddingDriver:
    """Unit tests for LmStudioEmbeddingDriver with the OpenAI client mocked out."""

    @pytest.fixture(autouse=True)
    def mock_openai(self, mocker):
        # Patch the client so no real HTTP request is made; every embed call
        # returns a fixed [0, 1, 0] vector.
        mock_chat_create = mocker.patch("openai.OpenAI").return_value.embeddings.create

        mock_embedding = Mock()
        mock_embedding.embedding = [0, 1, 0]
        mock_response = Mock()
        mock_response.data = [mock_embedding]

        mock_chat_create.return_value = mock_response

        return mock_chat_create

    def test_init(self):
        assert LmStudioEmbeddingDriver()

    def test_try_embed_chunk(self):
        assert LmStudioEmbeddingDriver().try_embed_chunk("foobar") == [0, 1, 0]

    @pytest.mark.parametrize("model", OpenAiTokenizer.EMBEDDING_MODELS)
    def test_try_embed_chunk_replaces_newlines_in_older_ada_models(self, model, mock_openai):
        LmStudioEmbeddingDriver(model=model).try_embed_chunk("foo\nbar")
        # BUG FIX: the original `assert x == "foo bar" if cond else "foo\nbar"`
        # parsed as `assert (x == "foo bar") if cond else "foo\nbar"`, so the
        # non-"001" branch asserted a bare truthy string and always passed.
        expected = "foo bar" if model.endswith("001") else "foo\nbar"
        assert mock_openai.call_args.kwargs["input"] == expected
104 changes: 104 additions & 0 deletions
104
tests/unit/drivers/prompt/test_lm_studio_prompt_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent | ||
from griptape.drivers import LmStudioPromptDriver | ||
from griptape.common import PromptStack | ||
from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact | ||
from unittest.mock import Mock | ||
import pytest | ||
|
||
|
||
class TestLmStudioPromptDriver:
    """Unit tests for LmStudioPromptDriver with openai.OpenAI mocked out."""

    @pytest.fixture
    def mock_chat_completion_create(self, mocker):
        # Non-streaming chat completion returning "model-output" and
        # 5 prompt / 10 completion tokens of usage.
        mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create
        mock_function = Mock(arguments='{"foo": "bar"}', id="mock-id")
        mock_function.name = "MockTool_test"
        mock_chat_create.return_value = Mock(
            headers={},
            choices=[Mock(message=Mock(content="model-output"))],
            usage=Mock(prompt_tokens=5, completion_tokens=10),
        )

        return mock_chat_create

    @pytest.fixture
    def mock_chat_completion_stream_create(self, mocker):
        # Streaming variant: one content chunk then a usage-only chunk.
        mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create
        mock_tool_call_delta_header = Mock()
        mock_tool_call_delta_header.name = "MockTool_test"
        mock_tool_call_delta_body = Mock(arguments='{"foo": "bar"}')
        mock_tool_call_delta_body.name = None

        mock_chat_create.return_value = iter(
            [
                Mock(choices=[Mock(delta=Mock(content="model-output", tool_calls=None))], usage=None),
                Mock(choices=None, usage=Mock(prompt_tokens=5, completion_tokens=10)),
            ]
        )
        return mock_chat_create

    def test_init(self):
        assert LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF")

    # BUG FIX: the original tests requested undefined fixtures
    # (`mock_client`/`mock_stream_client`) and asserted an Ollama-style call
    # (`client.chat`, `options`, `num_predict`, `images`) that this
    # OpenAI-based driver never makes; they also asserted usage was None even
    # though the mock supplies 5/10 tokens. They now use the fixtures defined
    # above and assert against the mocked chat.completions.create call.
    def test_try_run(self, mock_chat_completion_create):
        # Given
        prompt_stack = PromptStack()
        prompt_stack.add_system_message("system-input")
        prompt_stack.add_user_message("user-input")
        prompt_stack.add_user_message(
            ListArtifact(
                [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)]
            )
        )
        prompt_stack.add_assistant_message("assistant-input")
        driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF")

        # When
        message = driver.try_run(prompt_stack)

        # Then
        call_kwargs = mock_chat_completion_create.call_args.kwargs
        assert call_kwargs["model"] == driver.model
        # The exact OpenAI message payload (including the image content part)
        # is covered by the parent driver's tests; here we only check that all
        # four stack entries were converted in order.
        assert [m["role"] for m in call_kwargs["messages"]] == ["system", "user", "user", "assistant"]
        assert message.value == "model-output"
        assert message.usage.input_tokens == 5
        assert message.usage.output_tokens == 10

    def test_try_stream_run(self, mock_chat_completion_stream_create):
        # Given
        prompt_stack = PromptStack()
        prompt_stack.add_system_message("system-input")
        prompt_stack.add_user_message("user-input")
        prompt_stack.add_user_message(
            ListArtifact(
                [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)]
            )
        )
        prompt_stack.add_assistant_message("assistant-input")
        driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True)

        # When
        event = next(driver.try_stream(prompt_stack))

        # Then
        call_kwargs = mock_chat_completion_stream_create.call_args.kwargs
        assert call_kwargs["model"] == driver.model
        assert call_kwargs["stream"] is True
        assert [m["role"] for m in call_kwargs["messages"]] == ["system", "user", "user", "assistant"]
        if isinstance(event, TextDeltaMessageContent):
            assert event.text == "model-output"