Skip to content

Commit

Permalink
Add Structured Output functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 24, 2024
1 parent af1e5eb commit a905f83
Show file tree
Hide file tree
Showing 34 changed files with 835 additions and 153 deletions.
5 changes: 4 additions & 1 deletion griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from attrs import define, field

Expand All @@ -24,13 +24,16 @@
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from schema import Schema

from griptape.tools import BaseTool


@define
class PromptStack(SerializableMixin):
messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
tools: list[BaseTool] = field(factory=list, kw_only=True)
output_schema: Optional[Schema] = field(default=None, kw_only=True)

@property
def system_messages(self) -> list[Message]:
Expand Down
41 changes: 32 additions & 9 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
Expand Down Expand Up @@ -55,9 +55,20 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.")

return value

@lazy_property()
def client(self) -> Any:
return self.session.client("bedrock-runtime")
Expand Down Expand Up @@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:

def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]

messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

return {
params = {
"modelId": self.model,
"messages": messages,
"system": system_messages,
Expand All @@ -115,14 +125,27 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
},
"additionalModelRequestFields": self.additional_model_request_fields,
**(
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

if prompt_stack.tools and self.use_native_tools:
params["toolConfig"] = {
"tools": [],
"toolChoice": self.tool_choice,
}

if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_mode == "tool"
):
self._add_structured_output_tool(prompt_stack)
params["toolConfig"]["toolChoice"] = {"any": {}}

params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)

return params

def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
return [
{
Expand Down
37 changes: 29 additions & 8 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
Expand Down Expand Up @@ -68,13 +68,24 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Client:
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

@native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.")

return value

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
params = self._base_params(prompt_stack)
Expand Down Expand Up @@ -110,23 +121,33 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = prompt_stack.system_messages
system_message = system_messages[0].to_text() if system_messages else None

return {
params = {
"model": self.model,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
"top_p": self.top_p,
"top_k": self.top_k,
"max_tokens": self.max_tokens,
"messages": messages,
**(
{"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice}
if prompt_stack.tools and self.use_native_tools
else {}
),
**({"system": system_message} if system_message else {}),
**self.extra_params,
}

if prompt_stack.tools and self.use_native_tools:
params["tool_choice"] = self.tool_choice

if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_mode == "tool"
):
self._add_structured_output_tool(prompt_stack)
params["tool_choice"] = {"type": "any"}

params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)

return params

def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
return [
{"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}
Expand Down
16 changes: 15 additions & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field

Expand Down Expand Up @@ -56,6 +56,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
Expand Down Expand Up @@ -122,6 +126,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

def _add_structured_output_tool(self, prompt_stack: PromptStack) -> None:
from griptape.tools.structured_output.tool import StructuredOutputTool

if prompt_stack.output_schema is None:
raise ValueError("PromptStack must have an output schema to use structured output.")

structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

def __process_run(self, prompt_stack: PromptStack) -> Message:
return self.try_run(prompt_stack)

Expand Down
18 changes: 17 additions & 1 deletion griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver):
model: str = field(metadata={"serializable": True})
force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
Expand Down Expand Up @@ -101,7 +102,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:

messages = self.__to_cohere_messages(prompt_stack.messages)

return {
params = {
"model": self.model,
"messages": messages,
"temperature": self.temperature,
Expand All @@ -116,6 +117,21 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_mode == "native":
params["response_format"] = {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}
elif self.native_structured_output_mode == "tool":
# TODO: Implement tool choice once supported
self._add_structured_output_tool(prompt_stack)

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_cohere_tools(prompt_stack.tools)

return params

def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
cohere_messages = []

Expand Down
40 changes: 29 additions & 11 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import json
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import ActionArtifact, TextArtifact
Expand Down Expand Up @@ -63,9 +63,20 @@ class GooglePromptDriver(BasePromptDriver):
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
_client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("GooglePromptDriver does not support `native` structured output mode.")

return value

@lazy_property()
def client(self) -> GenerativeModel:
genai = import_optional_dependency("google.generativeai")
Expand Down Expand Up @@ -135,7 +146,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
)

return {
params = {
"generation_config": types.GenerationConfig(
**{
# For some reason, providing stop sequences when streaming breaks native functions
Expand All @@ -148,16 +159,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
},
),
**(
{
"tools": self.__to_google_tools(prompt_stack.tools),
"tool_config": {"function_calling_config": {"mode": self.tool_choice}},
}
if prompt_stack.tools and self.use_native_tools
else {}
),
}

if prompt_stack.tools and self.use_native_tools:
params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_mode == "tool"
):
params["tool_config"]["function_calling_config"]["mode"] = "auto"
self._add_structured_output_tool(prompt_stack)

params["tools"] = self.__to_google_tools(prompt_stack.tools)

return params

def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
types = import_optional_dependency("google.generativeai.types")

Expand Down
Loading

0 comments on commit a905f83

Please sign in to comment.