Add LmStudioPromptDriver and LmStudioEmbeddingDriver
collindutter committed Jul 10, 2024
1 parent 29481f7 commit f7ff4d9
Showing 9 changed files with 216 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Added
- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
- `LmStudioPromptDriver` for generating chat completions with LM Studio models.
- `LmStudioEmbeddingDriver` for generating embeddings with LM Studio models.

### Changed

17 changes: 17 additions & 0 deletions docs/griptape-framework/drivers/embedding-drivers.md
@@ -163,6 +163,23 @@ embeddings = embedding_driver.embed_string("Hello world!")
print(embeddings[:3])
```

### LM Studio

The [LmStudioEmbeddingDriver](../../reference/griptape/drivers/embedding/lm_studio_embedding_driver.md) uses the [LM Studio Embeddings API](https://lmstudio.ai/docs/text-embeddings).

```python title="PYTEST_IGNORE"
from griptape.drivers import LmStudioEmbeddingDriver

embedding_driver = LmStudioEmbeddingDriver(
    model="nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.Q2_K",
)

embeddings = embedding_driver.embed_string("Hello world!")

# display the first 3 embeddings
print(embeddings[:3])
```
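
The driver's `base_url` (added in this commit) can be overridden when the server is not on LM Studio's default port. A minimal sketch, assuming a server listening at a hypothetical `http://localhost:1235/v1`:

```python title="PYTEST_IGNORE"
from griptape.drivers import LmStudioEmbeddingDriver

# Hypothetical non-default address; point this at wherever LM Studio is serving.
embedding_driver = LmStudioEmbeddingDriver(
    base_url="http://localhost:1235/v1",
    model="nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.Q2_K",
)

embeddings = embedding_driver.embed_string("Hello world!")
```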

### Override Default Structure Embedding Driver
Here is how you can override the Embedding Driver that is used by default in Structures.

23 changes: 23 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -282,6 +282,29 @@ agent = Agent(
agent.run("What color is the sky at different times of the day?")
```


### LM Studio

The [LmStudioPromptDriver](../../reference/griptape/drivers/prompt/lm_studio_prompt_driver.md) connects to the [LM Studio Local Server](https://lmstudio.ai/docs/local-server).

```python title="PYTEST_IGNORE"
from griptape.structures import Agent
from griptape.drivers import LmStudioPromptDriver
from griptape.rules import Rule
from griptape.config import StructureConfig

agent = Agent(
    config=StructureConfig(
        prompt_driver=LmStudioPromptDriver(
            model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF"
        )
    ),
    rules=[Rule(value="You are a helpful coding assistant.")],
)

agent.run("How do I init and update a git submodule?")
```
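
Because `LmStudioPromptDriver` extends `OpenAiChatPromptDriver`, inherited options such as `stream` work unchanged. A minimal sketch of streaming against the same local server (pair it with an event listener to surface the deltas):

```python title="PYTEST_IGNORE"
from griptape.structures import Agent
from griptape.drivers import LmStudioPromptDriver
from griptape.config import StructureConfig

agent = Agent(
    config=StructureConfig(
        prompt_driver=LmStudioPromptDriver(
            model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF",
            stream=True,  # emit completion deltas as they arrive instead of one final message
        )
    ),
)

agent.run("Summarize what a git submodule is in one sentence.")
```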

### Hugging Face Hub

!!! info
4 changes: 4 additions & 0 deletions griptape/drivers/__init__.py
@@ -10,6 +10,7 @@
from .prompt.google_prompt_driver import GooglePromptDriver
from .prompt.dummy_prompt_driver import DummyPromptDriver
from .prompt.ollama_prompt_driver import OllamaPromptDriver
from .prompt.lm_studio_prompt_driver import LmStudioPromptDriver

from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
@@ -27,6 +28,7 @@
from .embedding.google_embedding_driver import GoogleEmbeddingDriver
from .embedding.dummy_embedding_driver import DummyEmbeddingDriver
from .embedding.cohere_embedding_driver import CohereEmbeddingDriver
from .embedding.lm_studio_embedding_driver import LmStudioEmbeddingDriver

from .vector.base_vector_store_driver import BaseVectorStoreDriver
from .vector.local_vector_store_driver import LocalVectorStoreDriver
@@ -120,6 +122,7 @@
"GooglePromptDriver",
"DummyPromptDriver",
"OllamaPromptDriver",
"LmStudioPromptDriver",
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
@@ -135,6 +138,7 @@
"GoogleEmbeddingDriver",
"DummyEmbeddingDriver",
"CohereEmbeddingDriver",
"LmStudioEmbeddingDriver",
"BaseVectorStoreDriver",
"LocalVectorStoreDriver",
"PineconeVectorStoreDriver",
17 changes: 17 additions & 0 deletions griptape/drivers/embedding/lm_studio_embedding_driver.py
@@ -0,0 +1,17 @@
from __future__ import annotations
from attrs import define, field, Factory
import openai

from griptape.drivers import OpenAiEmbeddingDriver


@define
class LmStudioEmbeddingDriver(OpenAiEmbeddingDriver):
    """
    Attributes:
        base_url: API URL. Defaults to LM Studio's v1 API URL.
        client: Optionally provide custom `openai.OpenAI` client.
    """

    base_url: str = field(default="http://localhost:1234/v1", kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(base_url=self.base_url), takes_self=True))
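
The `takes_self=True` factory is what lets the default client pick up whatever `base_url` the instance was initialized with. A minimal sketch of the same attrs pattern in isolation (the `Example` class and `client_url` field are hypothetical):

```python
from attrs import Factory, define, field


@define
class Example:
    base_url: str = field(default="http://localhost:1234/v1", kw_only=True)
    # The factory receives the instance, so a derived field can read base_url.
    client_url: str = field(default=Factory(lambda self: self.base_url + "/embeddings", takes_self=True))


print(Example(base_url="http://localhost:9999/v1").client_url)
# http://localhost:9999/v1/embeddings
```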
17 changes: 17 additions & 0 deletions griptape/drivers/prompt/lm_studio_prompt_driver.py
@@ -0,0 +1,17 @@
from __future__ import annotations
from attrs import define, field, Factory
from griptape.drivers.prompt.openai_chat_prompt_driver import OpenAiChatPromptDriver

import openai


@define
class LmStudioPromptDriver(OpenAiChatPromptDriver):
    """
    Attributes:
        base_url: API URL. Defaults to LM Studio's v1 API URL.
        client: Optionally provide custom `openai.OpenAI` client.
    """

    base_url: str = field(default="http://localhost:1234/v1", kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(base_url=self.base_url), takes_self=True))
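
Since `client` is an ordinary field, a preconfigured `openai.OpenAI` instance can be passed in place of the default factory. A minimal sketch, assuming a slow local model that warrants a longer timeout (the timeout value and dummy API key are illustrative):

```python
import openai

from griptape.drivers import LmStudioPromptDriver

# Hypothetical configuration: LM Studio's local server does not validate the key,
# but the openai client requires one unless OPENAI_API_KEY is set.
client = openai.OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio", timeout=120.0)

driver = LmStudioPromptDriver(
    model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF",
    client=client,
)
```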
3 changes: 2 additions & 1 deletion griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -150,7 +150,8 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:

         for message in messages:
             if message.is_text():
-                openai_messages.append({"role": message.role, "content": message.to_text()})
+                if message.to_text():
+                    openai_messages.append({"role": message.role, "content": message.to_text()})
             elif message.has_any_content_type(ActionResultMessageContent):
                 # Action results need to be expanded into separate messages.
                 openai_messages.extend(
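
The new guard skips text messages whose rendered content is empty, so no `{"content": ""}` entries reach the API. A minimal sketch of the effect, using a hypothetical message list in place of griptape's `Message` objects:

```python
# Hypothetical (role, text) pairs standing in for Message objects.
messages = [("user", "hello"), ("assistant", "")]

openai_messages = []
for role, text in messages:
    if text:  # mirrors the new `if message.to_text():` guard
        openai_messages.append({"role": role, "content": text})

assert openai_messages == [{"role": "user", "content": "hello"}]
```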
30 changes: 30 additions & 0 deletions tests/unit/drivers/embedding/test_lm_studio_embedding_driver.py
@@ -0,0 +1,30 @@
from unittest.mock import Mock
import pytest
from griptape.drivers import LmStudioEmbeddingDriver
from griptape.tokenizers import OpenAiTokenizer


class TestLmStudioEmbeddingDriver:
    @pytest.fixture(autouse=True)
    def mock_openai(self, mocker):
        mock_embeddings_create = mocker.patch("openai.OpenAI").return_value.embeddings.create

        mock_embedding = Mock()
        mock_embedding.embedding = [0, 1, 0]
        mock_response = Mock()
        mock_response.data = [mock_embedding]

        mock_embeddings_create.return_value = mock_response

        return mock_embeddings_create

    def test_init(self):
        assert LmStudioEmbeddingDriver()

    def test_try_embed_chunk(self):
        assert LmStudioEmbeddingDriver().try_embed_chunk("foobar") == [0, 1, 0]

    @pytest.mark.parametrize("model", OpenAiTokenizer.EMBEDDING_MODELS)
    def test_try_embed_chunk_replaces_newlines_in_older_ada_models(self, model, mock_openai):
        LmStudioEmbeddingDriver(model=model).try_embed_chunk("foo\nbar")
        # Parenthesized so the assert compares against the right expectation for each model.
        assert mock_openai.call_args.kwargs["input"] == ("foo bar" if model.endswith("001") else "foo\nbar")
104 changes: 104 additions & 0 deletions tests/unit/drivers/prompt/test_lm_studio_prompt_driver.py
@@ -0,0 +1,104 @@
from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
from griptape.drivers import LmStudioPromptDriver
from griptape.common import PromptStack
from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact
from unittest.mock import Mock
import pytest


class TestLmStudioPromptDriver:
    @pytest.fixture
    def mock_chat_completion_create(self, mocker):
        mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create
        mock_chat_create.return_value = Mock(
            headers={},
            choices=[Mock(message=Mock(content="model-output"))],
            usage=Mock(prompt_tokens=5, completion_tokens=10),
        )

        return mock_chat_create

    @pytest.fixture
    def mock_chat_completion_stream_create(self, mocker):
        mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create
        mock_chat_create.return_value = iter(
            [
                Mock(choices=[Mock(delta=Mock(content="model-output", tool_calls=None))], usage=None),
                Mock(choices=None, usage=Mock(prompt_tokens=5, completion_tokens=10)),
            ]
        )
        return mock_chat_create

    def test_init(self):
        assert LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF")

    def test_try_run(self, mock_chat_completion_create):
        # Given
        prompt_stack = PromptStack()
        prompt_stack.add_system_message("system-input")
        prompt_stack.add_user_message("user-input")
        prompt_stack.add_user_message(
            ListArtifact(
                [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)]
            )
        )
        prompt_stack.add_assistant_message("assistant-input")
        driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF")
        # OpenAI-format messages; images are rendered as image_url content parts.
        expected_messages = [
            {"role": "system", "content": "system-input"},
            {"role": "user", "content": "user-input"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "user-input"},
                    {"type": "image_url", "image_url": {"url": "data:image/png;base64,aW1hZ2UtZGF0YQ=="}},
                ],
            },
            {"role": "assistant", "content": "assistant-input"},
        ]

        # When
        message = driver.try_run(prompt_stack)

        # Then
        call_kwargs = mock_chat_completion_create.call_args.kwargs
        assert call_kwargs["messages"] == expected_messages
        assert call_kwargs["model"] == driver.model
        assert message.value == "model-output"
        assert message.usage.input_tokens == 5
        assert message.usage.output_tokens == 10

    def test_try_stream_run(self, mock_chat_completion_stream_create):
        # Given
        prompt_stack = PromptStack()
        prompt_stack.add_system_message("system-input")
        prompt_stack.add_user_message("user-input")
        prompt_stack.add_user_message(
            ListArtifact(
                [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)]
            )
        )
        prompt_stack.add_assistant_message("assistant-input")
        expected_messages = [
            {"role": "system", "content": "system-input"},
            {"role": "user", "content": "user-input"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "user-input"},
                    {"type": "image_url", "image_url": {"url": "data:image/png;base64,aW1hZ2UtZGF0YQ=="}},
                ],
            },
            {"role": "assistant", "content": "assistant-input"},
        ]
        driver = LmStudioPromptDriver(model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True)

        # When
        event = next(driver.try_stream(prompt_stack))

        # Then
        call_kwargs = mock_chat_completion_stream_create.call_args.kwargs
        assert call_kwargs["messages"] == expected_messages
        assert call_kwargs["model"] == driver.model
        assert call_kwargs["stream"] is True
        assert isinstance(event.content, TextDeltaMessageContent)
        assert event.content.text == "model-output"
