
Commit d029a69

Merge branch 'dev' into feature/lm-studio-prompt-driver
collindutter committed Jul 10, 2024
2 parents d9f483f + 9c44a62 commit d029a69
Showing 36 changed files with 224 additions and 142 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
- `LmStudioPromptDriver` for generating chat completions with LmStudio models.
- `LmStudioEmbeddingDriver` for generating embeddings with LmStudio models.
- `OllamaEmbeddingDriver` for generating embeddings with Ollama.

### Changed

20 changes: 20 additions & 0 deletions docs/griptape-framework/drivers/embedding-drivers.md
@@ -97,8 +97,28 @@ embeddings = driver.embed_string("Hello world!")

# display the first 3 embeddings
print(embeddings[:3])
```

### Ollama

!!! info
This driver requires the `drivers-embedding-ollama` [extra](../index.md#extras).

The [OllamaEmbeddingDriver](../../reference/griptape/drivers/embedding/ollama_embedding_driver.md) uses the [Ollama Embeddings API](https://ollama.com/blog/embedding-models).

```python title="PYTEST_IGNORE"
from griptape.drivers import OllamaEmbeddingDriver

driver = OllamaEmbeddingDriver(
model="all-minilm",
)

results = driver.embed_string("Hello world!")

# display the first 3 embeddings
print(results[:3])
```

### Amazon SageMaker Jumpstart

The [AmazonSageMakerJumpstartEmbeddingDriver](../../reference/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.md) uses the [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) to generate embeddings on AWS.
15 changes: 2 additions & 13 deletions griptape/artifacts/action_artifact.py
@@ -7,23 +7,12 @@
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.common import Action
from griptape.common import ToolAction


@define()
class ActionArtifact(BaseArtifact, SerializableMixin):
"""Represents an instance of an LLM calling a Action.
Attributes:
tag: The tag (unique identifier) of the action.
name: The name (Tool name) of the action.
path: The path (Tool activity name) of the action.
input: The input (Tool params) of the action.
tool: The matched Tool of the action.
output: The output (Tool result) of the action.
"""

value: Action = field(metadata={"serializable": True})
value: ToolAction = field(metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> ActionArtifact:
raise NotImplementedError
6 changes: 4 additions & 2 deletions griptape/common/__init__.py
@@ -1,4 +1,5 @@
from .action import Action
from .actions.base_action import BaseAction
from .actions.tool_action import ToolAction

from .prompt_stack.contents.base_message_content import BaseMessageContent
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
@@ -32,5 +33,6 @@
"ActionResultMessageContent",
"PromptStack",
"Reference",
"Action",
"BaseAction",
"ToolAction",
]
Empty file.
5 changes: 5 additions & 0 deletions griptape/common/actions/base_action.py
@@ -0,0 +1,5 @@
from griptape.mixins import SerializableMixin
from abc import ABC


class BaseAction(SerializableMixin, ABC): ...
@@ -5,27 +5,35 @@
from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.mixins import SerializableMixin
from griptape.common import BaseAction

if TYPE_CHECKING:
from griptape.tools import BaseTool


@define(kw_only=True)
class Action(SerializableMixin):
class ToolAction(BaseAction):
"""Represents an instance of an LLM using a Tool.
Attributes:
tag: The tag (unique identifier) of the action.
name: The name (Tool name) of the action.
path: The path (Tool activity name) of the action.
input: The input (Tool params) of the action.
tool: The matched Tool of the action.
output: The output (Tool result) of the action.
"""

tag: str = field(metadata={"serializable": True})
name: str = field(metadata={"serializable": True})
path: Optional[str] = field(default=None, metadata={"serializable": True})
input: dict = field(factory=dict, metadata={"serializable": True})
tool: Optional[BaseTool] = field(default=None)
output: Optional[BaseArtifact] = field(default=None, metadata={"serializable": True})
output: Optional[BaseArtifact] = field(default=None)

def __str__(self) -> str:
return json.dumps(self.to_dict())

def to_dict(self) -> dict:
return {"tag": self.tag, "name": self.name, "path": self.path, "input": self.input}

def to_native_tool_name(self) -> str:
parts = [self.name]

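The `to_native_tool_name` body is truncated in this excerpt. As a minimal sketch of the renamed class (tag/name/path/input values invented for illustration, not taken from the commit), constructing a `ToolAction` and serializing it with the `to_dict` shown above looks roughly like this:

```python
from griptape.common import ToolAction

# Hypothetical values; in practice these come from an LLM tool call.
action = ToolAction(
    tag="call-123",
    name="Calculator",
    path="calculate",
    input={"values": {"expression": "1 + 1"}},
)

# to_dict keeps only the serializable call fields defined above.
print(action.to_dict())
# -> {'tag': 'call-123', 'name': 'Calculator', 'path': 'calculate', 'input': {'values': {'expression': '1 + 1'}}}

# __str__ is the JSON form of the same dict.
print(action)
```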
@@ -5,7 +5,7 @@

from attrs import define, field

from griptape.common import Action
from griptape.common import ToolAction
from griptape.artifacts import ActionArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ActionCallDeltaMessageContent

@@ -37,10 +37,10 @@ def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionCallMes
try:
parsed_input = json.loads(input)
except json.JSONDecodeError as exc:
raise ValueError("Invalid JSON input for Action") from exc
action = Action(tag=tag, name=name, path=path, input=parsed_input)
raise ValueError("Invalid JSON input for ToolAction") from exc
action = ToolAction(tag=tag, name=name, path=path, input=parsed_input)
else:
raise ValueError("Missing required fields for Action")
raise ValueError("Missing required fields for ToolAction")

artifact = ActionArtifact(value=action)

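For context on how these contents are used during streaming, here is a hedged sketch of assembling an action call from deltas; the delta values are invented, and the exact merging of tag/name/path across deltas lives in the truncated part of `from_deltas` above.

```python
from griptape.common import ActionCallDeltaMessageContent, ActionCallMessageContent

# Invented deltas: the first carries the call header, the second a chunk of JSON input.
deltas = [
    ActionCallDeltaMessageContent(index=0, tag="call-123", name="Calculator", path="calculate"),
    ActionCallDeltaMessageContent(index=0, partial_input='{"values": {"expression": "1 + 1"}}'),
]

# from_deltas parses the accumulated input as JSON and wraps the resulting
# ToolAction in an ActionArtifact, as shown in the excerpt above.
content = ActionCallMessageContent.from_deltas(deltas)
print(content.artifact.value.to_dict())
```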
@@ -5,13 +5,13 @@
from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, Action
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ToolAction


@define
class ActionResultMessageContent(BaseMessageContent):
artifact: BaseArtifact = field(metadata={"serializable": True})
action: Action = field(metadata={"serializable": True})
action: ToolAction = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionResultMessageContent:
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -29,6 +29,7 @@
from .embedding.dummy_embedding_driver import DummyEmbeddingDriver
from .embedding.cohere_embedding_driver import CohereEmbeddingDriver
from .embedding.lm_studio_embedding_driver import LmStudioEmbeddingDriver
from .embedding.ollama_embedding_driver import OllamaEmbeddingDriver

from .vector.base_vector_store_driver import BaseVectorStoreDriver
from .vector.local_vector_store_driver import LocalVectorStoreDriver
@@ -139,6 +140,7 @@
"DummyEmbeddingDriver",
"CohereEmbeddingDriver",
"LmStudioEmbeddingDriver",
"OllamaEmbeddingDriver",
"BaseVectorStoreDriver",
"LocalVectorStoreDriver",
"PineconeVectorStoreDriver",
28 changes: 28 additions & 0 deletions griptape/drivers/embedding/ollama_embedding_driver.py
@@ -0,0 +1,28 @@
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from attrs import define, field, Factory
from griptape.utils import import_optional_dependency
from griptape.drivers import BaseEmbeddingDriver

if TYPE_CHECKING:
from ollama import Client


@define
class OllamaEmbeddingDriver(BaseEmbeddingDriver):
"""
Attributes:
model: Ollama embedding model name.
host: Optional Ollama host.
client: Ollama `Client`.
"""

model: str = field(kw_only=True, metadata={"serializable": True})
host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True),
kw_only=True,
)

def try_embed_chunk(self, chunk: str) -> list[float]:
return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])
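A quick usage sketch for this driver, assuming the `drivers-embedding-ollama` extra is installed, a local Ollama server is reachable (the default host is assumed to be `http://localhost:11434`), and the `all-minilm` model has been pulled:

```python
from griptape.drivers import OllamaEmbeddingDriver

# host is optional; leaving it as None falls back to the ollama Client default.
driver = OllamaEmbeddingDriver(model="all-minilm", host="http://localhost:11434")

# embed_string comes from BaseEmbeddingDriver and ultimately calls try_embed_chunk above.
embedding = driver.embed_string("Hello world!")
print(len(embedding), embedding[:3])
```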
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -26,7 +26,7 @@
Message,
TextDeltaMessageContent,
TextMessageContent,
Action,
ToolAction,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer
@@ -177,10 +177,10 @@ def __to_prompt_stack_message_content(self, content: dict) -> BaseMessageContent
if "text" in content:
return TextMessageContent(TextArtifact(content["text"]))
elif "toolUse" in content:
name, path = Action.from_native_tool_name(content["toolUse"]["name"])
name, path = ToolAction.from_native_tool_name(content["toolUse"]["name"])
return ActionCallMessageContent(
artifact=ActionArtifact(
value=Action(
value=ToolAction(
tag=content["toolUse"]["toolUseId"], name=name, path=path, input=content["toolUse"]["input"]
)
)
@@ -193,7 +193,7 @@ def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessa
content_block = event["contentBlockStart"]["start"]

if "toolUse" in content_block:
name, path = Action.from_native_tool_name(content_block["toolUse"]["name"])
name, path = ToolAction.from_native_tool_name(content_block["toolUse"]["name"])

return ActionCallDeltaMessageContent(
index=event["contentBlockStart"]["contentBlockIndex"],
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -26,7 +26,7 @@
ImageMessageContent,
PromptStack,
Message,
Action,
ToolAction,
TextMessageContent,
)
from griptape.drivers import BasePromptDriver
@@ -183,11 +183,11 @@ def __to_prompt_stack_message_content(self, content: ContentBlock) -> BaseMessag
if content.type == "text":
return TextMessageContent(TextArtifact(content.text))
elif content.type == "tool_use":
name, path = Action.from_native_tool_name(content.name)
name, path = ToolAction.from_native_tool_name(content.name)

return ActionCallMessageContent(
artifact=ActionArtifact(
value=Action(tag=content.id, name=name, path=path, input=content.input) # pyright: ignore[reportArgumentType]
value=ToolAction(tag=content.id, name=name, path=path, input=content.input) # pyright: ignore[reportArgumentType]
)
)
else:
@@ -200,7 +200,7 @@ def __to_prompt_stack_delta_message_content(
content_block = event.content_block

if content_block.type == "tool_use":
name, path = Action.from_native_tool_name(content_block.name)
name, path = ToolAction.from_native_tool_name(content_block.name)

return ActionCallDeltaMessageContent(index=event.index, tag=content_block.id, name=name, path=path)
elif content_block.type == "text":
10 changes: 5 additions & 5 deletions griptape/drivers/prompt/cohere_prompt_driver.py
@@ -18,7 +18,7 @@
Message,
TextMessageContent,
ActionResultMessageContent,
Action,
ToolAction,
)
from griptape.utils import import_optional_dependency
from griptape.tokenizers import BaseTokenizer
@@ -208,10 +208,10 @@ def __to_prompt_stack_message_content(self, response: NonStreamedChatResponse) -
[
ActionCallMessageContent(
ActionArtifact(
Action(
ToolAction(
tag=tool_call.name,
name=Action.from_native_tool_name(tool_call.name)[0],
path=Action.from_native_tool_name(tool_call.name)[1],
name=ToolAction.from_native_tool_name(tool_call.name)[0],
path=ToolAction.from_native_tool_name(tool_call.name)[1],
input=tool_call.parameters,
)
)
@@ -229,7 +229,7 @@ def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessag
if event.tool_call_delta is not None:
tool_call_delta = event.tool_call_delta
if tool_call_delta.name is not None:
name, path = Action.from_native_tool_name(tool_call_delta.name)
name, path = ToolAction.from_native_tool_name(tool_call_delta.name)

return ActionCallDeltaMessageContent(tag=tool_call_delta.name, name=name, path=path)
else:
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/google_prompt_driver.py
@@ -20,7 +20,7 @@
ActionResultMessageContent,
ActionCallDeltaMessageContent,
BaseDeltaMessageContent,
Action,
ToolAction,
)
from griptape.artifacts import TextArtifact, ActionArtifact
from griptape.drivers import BasePromptDriver
@@ -219,11 +219,11 @@ def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent
elif content.function_call:
function_call = content.function_call

name, path = Action.from_native_tool_name(function_call.name)
name, path = ToolAction.from_native_tool_name(function_call.name)

args = {k: v for k, v in function_call.args.items()}
return ActionCallMessageContent(
artifact=ActionArtifact(value=Action(tag=function_call.name, name=name, path=path, input=args))
artifact=ActionArtifact(value=ToolAction(tag=function_call.name, name=name, path=path, input=args))
)
else:
raise ValueError(f"Unsupported message content type {content}")
@@ -234,7 +234,7 @@ def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMes
elif content.function_call:
function_call = content.function_call

name, path = Action.from_native_tool_name(function_call.name)
name, path = ToolAction.from_native_tool_name(function_call.name)

args = {k: v for k, v in function_call.args.items()}
return ActionCallDeltaMessageContent(
18 changes: 9 additions & 9 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -21,7 +21,7 @@
PromptStack,
Message,
TextMessageContent,
Action,
ToolAction,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer
@@ -153,7 +153,7 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
if message.to_text():
openai_messages.append({"role": message.role, "content": message.to_text()})
elif message.has_any_content_type(ActionResultMessageContent):
# Action results need to be expanded into separate messages.
# ToolAction results need to be expanded into separate messages.
openai_messages.extend(
[
{
@@ -165,7 +165,7 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
]
)
else:
# Action calls are attached to the assistant message that originally generated them.
# ToolAction calls are attached to the assistant message that originally generated them.
action_call_content = []
non_action_call_content = []
for content in message.content:
@@ -179,7 +179,7 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
"role": self.__to_openai_role(message),
"content": [
self.__to_openai_message_content(content)
for content in non_action_call_content # Action calls do not belong in the content
for content in non_action_call_content # ToolAction calls do not belong in the content
],
**(
{
@@ -251,10 +251,10 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
[
ActionCallMessageContent(
ActionArtifact(
Action(
ToolAction(
tag=tool_call.id,
name=Action.from_native_tool_name(tool_call.function.name)[0],
path=Action.from_native_tool_name(tool_call.function.name)[1],
name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
input=json.loads(tool_call.function.arguments),
)
)
@@ -280,8 +280,8 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
return ActionCallDeltaMessageContent(
index=index,
tag=tool_call.id,
name=Action.from_native_tool_name(tool_call.function.name)[0],
path=Action.from_native_tool_name(tool_call.function.name)[1],
name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
)
else:
return ActionCallDeltaMessageContent(index=index, partial_input=tool_call.function.arguments)
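To ground the two comments above about message shapes: roughly, an assistant message carries its tool calls in `tool_calls` rather than in `content`, and each tool result is expanded into its own `role: "tool"` message that references the call id. A hedged illustration of the layout (field values invented; the exact native tool name joining lives in `ToolAction.to_native_tool_name`):

```python
# Illustrative only; values are made up, but the layout follows OpenAI's
# chat-completions tool-calling format used by the driver above.
assistant_message = {
    "role": "assistant",
    "content": [],  # ToolAction calls are kept out of the content
    "tool_calls": [
        {
            "id": "call-123",
            "type": "function",
            "function": {
                "name": "Calculator_calculate",  # assumed joined name/path
                "arguments": '{"values": {"expression": "1 + 1"}}',
            },
        }
    ],
}

# Each ToolAction result becomes its own message referencing the call id.
tool_result_message = {
    "role": "tool",
    "tool_call_id": "call-123",
    "content": "2",
}
```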