diff --git a/CHANGELOG.md b/CHANGELOG.md index b20d4c280..62b069cdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors. +- Structured Output support for all Prompt Drivers. +- `PromptTask.output_schema` for setting an output schema to be used with Structured Output. +- `Agent.output_schema` for setting an output schema to be used on the Agent's Prompt Task. +- `BasePromptDriver.structured_output_strategy` for changing the Structured Output strategy between `native`, `tool`, and `rule`. ## [1.1.1] - 2025-01-03 @@ -31,8 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Structure.run_stream()` for streaming Events from a Structure as an iterator. - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. -- `BasePromptDriver.use_structured_output` for enabling or disabling structured output. -- `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`. ### Changed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 22c3dd4ff..a6694726b 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -27,30 +27,27 @@ You can pass images to the Driver if the model supports it: ## Structured Output -Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. +Some LLMs provide functionality often referred to as "Structured Output". +This means instructing the LLM to output data in a particular format, usually JSON. +This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. -Structured output can be enabled or disabled for a Prompt Driver by setting the [use_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_structured_output). - -If `use_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: - -- `native`: The Driver will use the LLM's structured output functionality provided by the API. -- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md) and try to force the LLM to use a Tool. - -Each Driver may have a different default setting depending on the LLM provider's capabilities. +!!! warning + Each Driver may have a different default setting depending on the LLM provider's capabilities. ### Prompt Task The easiest way to get started with structured output is by using a [PromptTask](../structures/tasks.md#prompt)'s [output_schema](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.PromptTask.output_schema) parameter. 
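For reference, the snippet included above (`prompt_drivers_structured_output.py`, whose diff appears further down) boils down to roughly the following sketch before the strategy options are described. The schema fields and prompt text are illustrative assumptions; `PromptTask`, `output_schema`, and `structured_output_strategy` are the pieces introduced by this change.

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

# The schema fields and prompt below are illustrative; the snippet's exact contents
# are not fully shown in this diff.
pipeline = Pipeline(
    tasks=[
        PromptTask(
            prompt_driver=OpenAiChatPromptDriver(
                model="gpt-4o",
                structured_output_strategy="native",  # optional; "native" is this driver's default
            ),
            output_schema=schema.Schema({"city": str, "country": str}),
        )
    ]
)

pipeline.run("Name a city and the country it is in.")
print(pipeline.output.value)  # parsed JSON for the "native" and "rule" strategies
```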
+You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of: + +- `native`: The Driver will use the LLM's structured output functionality provided by the API. +- `tool`: The Task will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and the Driver will try to force the LLM to use the Tool. +- `rule`: The Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort. + ```python --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" ``` -If `use_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. - -!!! warning - Not every LLM supports `use_structured_output` or all `structured_output_strategy` options. - ## Prompt Drivers Griptape offers the following Prompt Drivers for interacting with LLMs. diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py index 8f5d0b77b..cb7eb5ceb 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -11,7 +11,6 @@ PromptTask( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", - use_structured_output=True, # optional structured_output_strategy="native", # optional ), output_schema=schema.Schema( diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index f4837bdeb..12ea13ad5 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any from attrs import Attribute, Factory, define, field from schema import Schema @@ -41,6 +41,7 @@ import boto3 from griptape.common import PromptStack + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -55,17 +56,16 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: StructuredOutputStrategy = field( default="tool", kw_only=True, metadata={"serializable": True} ) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "native": - raise 
ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.") + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @@ -134,12 +134,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "toolChoice": self.tool_choice, } - if ( - prompt_stack.output_schema is not None - and self.use_structured_output - and self.structured_output_strategy == "tool" - ): - self._add_structured_output_tool_if_absent(prompt_stack) + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": params["toolConfig"]["toolChoice"] = {"any": {}} params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index d98ac9fd4..bc0e28266 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -20,6 +20,7 @@ import boto3 from griptape.common import PromptStack + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -39,8 +40,18 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): ), kw_only=True, ) + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} + ) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value != "rule": + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @lazy_property() def client(self) -> Any: return self.session.client("sagemaker-runtime") diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 22eaf0d30..9a558e7cf 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Optional from attrs import Attribute, Factory, define, field from schema import Schema @@ -42,6 +42,7 @@ from anthropic import Client from anthropic.types import ContentBlock, ContentBlockDeltaEvent, ContentBlockStartEvent + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools.base_tool import BaseTool @@ -68,8 +69,7 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: StructuredOutputStrategy = field( default="tool", kw_only=True, metadata={"serializable": True} ) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) @@ -80,9 +80,9 @@ def 
client(self) -> Client: return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "native": - raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.") + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @@ -136,12 +136,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_choice"] = self.tool_choice - if ( - prompt_stack.output_schema is not None - and self.use_structured_output - and self.structured_output_strategy == "tool" - ): - self._add_structured_output_tool_if_absent(prompt_stack) + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": params["tool_choice"] = {"type": "any"} params["tools"] = self.__to_anthropic_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index eb00adee4..c5ffb7259 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -32,6 +32,8 @@ from griptape.tokenizers import BaseTokenizer +StructuredOutputStrategy = Literal["native", "tool", "rule"] + @define(kw_only=True) class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): @@ -56,9 +58,8 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - structured_output_strategy: Literal["native", "tool"] = field( - default="native", kw_only=True, metadata={"serializable": True} + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} ) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) @@ -126,16 +127,6 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... @abstractmethod def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... 
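The concrete drivers in this diff narrow the new `StructuredOutputStrategy` `Literal` with attrs validators. As a hedged sketch of that pattern, a hypothetical custom driver that only supports the `rule` strategy would look roughly like this (the class name is an assumption, and the required `try_run`/`try_stream` implementations are omitted):

```python
from __future__ import annotations

from attrs import Attribute, define, field

from griptape.drivers.prompt.base_prompt_driver import BasePromptDriver, StructuredOutputStrategy


@define
class MyPromptDriver(BasePromptDriver):  # hypothetical; try_run/try_stream omitted for brevity
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule", kw_only=True, metadata={"serializable": True}
    )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        # Same rejection pattern the Bedrock, Anthropic, Google, and SageMaker drivers use elsewhere in this diff.
        if value != "rule":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
        return value
```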
- def _add_structured_output_tool_if_absent(self, prompt_stack: PromptStack) -> None: - from griptape.tools.structured_output.tool import StructuredOutputTool - - if prompt_stack.output_schema is None: - raise ValueError("PromptStack must have an output schema to use structured output.") - - structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) - if structured_output_tool not in prompt_stack.tools: - prompt_stack.tools.append(structured_output_tool) - def __process_run(self, prompt_stack: PromptStack) -> Message: return self.try_run(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 4810aad65..9158c4ad1 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,7 +53,6 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -112,15 +111,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_structured_output: - if self.structured_output_strategy == "native": - params["response_format"] = { - "type": "json_object", - "schema": prompt_stack.output_schema.json_schema("Output"), - } - elif self.structured_output_strategy == "tool": - # TODO: Implement tool choice once supported - self._add_structured_output_tool_if_absent(prompt_stack) + if prompt_stack.output_schema is not None and self.structured_output_strategy == "native": + params["response_format"] = { + "type": "json_object", + "schema": prompt_stack.output_schema.json_schema("Output"), + } if prompt_stack.tools and self.use_native_tools: params["tools"] = self.__to_cohere_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index cb7ac47b5..46a721b08 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -2,7 +2,7 @@ import json import logging -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Optional from attrs import Attribute, Factory, define, field from schema import Schema @@ -37,6 +37,7 @@ from google.generativeai.protos import Part from google.generativeai.types import ContentDict, ContentsType, GenerateContentResponse + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -63,17 +64,16 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) 
- structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: StructuredOutputStrategy = field( default="tool", kw_only=True, metadata={"serializable": True} ) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "native": - raise ValueError("GooglePromptDriver does not support `native` structured output mode.") + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @@ -164,13 +164,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} - if ( - prompt_stack.output_schema is not None - and self.use_structured_output - and self.structured_output_strategy == "tool" - ): + if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": params["tool_config"]["function_calling_config"]["mode"] = "auto" - self._add_structured_output_tool_if_absent(prompt_stack) params["tools"] = self.__to_google_tools(prompt_stack.tools) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index e0a35048f..57a487450 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from attrs import Attribute, Factory, define, field @@ -17,6 +17,8 @@ from huggingface_hub import InferenceClient + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy + logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -35,8 +37,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - structured_output_strategy: Literal["native", "tool"] = field( + structured_output_strategy: StructuredOutputStrategy = field( default="native", kw_only=True, metadata={"serializable": True} ) tokenizer: HuggingFaceTokenizer = field( @@ -56,9 +57,9 @@ def client(self) -> InferenceClient: ) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_structured_output_strategy(self, attribute: Attribute, value: str) -> str: + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "tool": - raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.") + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @@ -121,7 +122,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema and 
self.use_structured_output and self.structured_output_strategy == "native": + if prompt_stack.output_schema and self.structured_output_strategy == "native": # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.output_schema.json_schema("Output Schema") # Grammar does not support $schema and $id diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index a197523df..866f033ec 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -3,7 +3,7 @@ import logging from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts import TextArtifact from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent, observable @@ -18,6 +18,8 @@ from transformers import TextGenerationPipeline + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy + logger = logging.getLogger(Defaults.logging_config.logger_name) @@ -38,10 +40,20 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ), kw_only=True, ) + structured_output_strategy: StructuredOutputStrategy = field( + default="rule", kw_only=True, metadata={"serializable": True} + ) _pipeline: TextGenerationPipeline = field( default=None, kw_only=True, alias="pipeline", metadata={"serializable": False} ) + @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: + if value in ("native", "tool"): + raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") + + return value + @lazy_property() def pipeline(self) -> TextGenerationPipeline: return import_optional_dependency("transformers").pipeline( diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index fd3b24524..1c4ae3fd1 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,7 +68,6 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -110,12 +109,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_structured_output: - if self.structured_output_strategy == "native": - params["format"] = prompt_stack.output_schema.json_schema("Output") - elif self.structured_output_strategy == "tool": - # TODO: Implement tool choice once supported - self._add_structured_output_tool_if_absent(prompt_stack) + if prompt_stack.output_schema is not None and self.structured_output_strategy == "native": + params["format"] = prompt_stack.output_schema.json_schema("Output") # Tool calling is only supported when not streaming if prompt_stack.tools and self.use_native_tools and not self.stream: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 5a1029eee..03390d687 100644 --- 
a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -35,6 +35,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.chat.chat_completion_message import ChatCompletionMessage + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool @@ -76,7 +77,9 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - use_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + structured_output_strategy: StructuredOutputStrategy = field( + default="native", kw_only=True, metadata={"serializable": True} + ) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] = field( default=Factory( @@ -159,7 +162,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["tool_choice"] = self.tool_choice params["parallel_tool_calls"] = self.parallel_tool_calls - if prompt_stack.output_schema is not None and self.use_structured_output: + if prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_schema", @@ -171,7 +174,6 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: } elif self.structured_output_strategy == "tool" and self.use_native_tools: params["tool_choice"] = "required" - self._add_structured_output_tool_if_absent(prompt_stack) if self.response_format is not None: if self.response_format == {"type": "json_object"}: diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 4432c1080..fa622bd05 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -172,6 +172,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: BaseTextToSpeechDriver, BaseVectorStoreDriver, ) + from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.events import EventListener from griptape.memory import TaskMemory from griptape.memory.structure import BaseConversationMemory, Run @@ -216,6 +217,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseArtifactStorage": BaseArtifactStorage, "BaseRule": BaseRule, "Ruleset": Ruleset, + "StructuredOutputStrategy": StructuredOutputStrategy, # Third party modules "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any, "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any, diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index baf36108f..9b70b7fb1 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -12,6 +12,8 @@ from griptape.tasks import PromptTask if TYPE_CHECKING: + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.drivers import BasePromptDriver from griptape.tasks import BaseTask @@ -25,6 +27,7 @@ class Agent(Structure): ) stream: bool = field(default=None, kw_only=True) prompt_driver: BasePromptDriver = field(default=None, kw_only=True) + output_schema: Optional[Schema] = field(default=None, kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) 
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) @@ -98,6 +101,7 @@ def _init_task(self) -> None: self.input, prompt_driver=self.prompt_driver, tools=self.tools, + output_schema=self.output_schema, max_meta_memory_entries=self.max_meta_memory_entries, ) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index ae80effcb..15c0f7457 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -15,6 +15,7 @@ from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.mixins.rule_mixin import RuleMixin from griptape.rules import Ruleset +from griptape.rules.json_schema_rule import JsonSchemaRule from griptape.tasks import ActionsSubtask, BaseTask from griptape.utils import J2 @@ -91,9 +92,16 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], @property def prompt_stack(self) -> PromptStack: - stack = PromptStack(tools=self.tools, output_schema=self.output_schema) + from griptape.tools.structured_output.tool import StructuredOutputTool + + stack = PromptStack(tools=self.tools) memory = self.structure.conversation_memory if self.structure is not None else None + if self.output_schema is not None: + stack.output_schema = self.output_schema + if self.prompt_driver.structured_output_strategy == "tool": + stack.tools.append(StructuredOutputTool(output_schema=stack.output_schema)) + system_template = self.generate_system_template(self) if system_template: stack.add_system_message(system_template) @@ -190,7 +198,7 @@ def try_run(self) -> BaseArtifact: else: output = result.to_artifact() - if self.prompt_driver.use_structured_output and self.prompt_driver.structured_output_strategy == "native": + if self.output_schema is not None and self.prompt_driver.structured_output_strategy in ("native", "rule"): return JsonArtifact(output.value) else: return output @@ -210,8 +218,6 @@ def preprocess(self, structure: Structure) -> BaseTask: return self def default_generate_system_template(self, _: PromptTask) -> str: - from griptape.rules import JsonSchemaRule - schema = self.actions_schema().json_schema("Actions Schema") schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. @@ -221,8 +227,8 @@ def default_generate_system_template(self, _: PromptTask) -> str: actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), use_native_tools=self.prompt_driver.use_native_tools, - use_structured_output=self.prompt_driver.use_structured_output, - json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output Schema")) + structured_output_strategy=self.prompt_driver.structured_output_strategy, + json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output")) if self.output_schema is not None else None, stop_sequence=self.response_stop_sequence, diff --git a/griptape/templates/tasks/prompt_task/system.j2 b/griptape/templates/tasks/prompt_task/system.j2 index e1a8bb21b..8e89e13c7 100644 --- a/griptape/templates/tasks/prompt_task/system.j2 +++ b/griptape/templates/tasks/prompt_task/system.j2 @@ -26,7 +26,7 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. 
NEVER {{ rulesets }} {% endif %} -{% if not use_structured_output and json_schema_rule %} +{% if json_schema_rule and structured_output_strategy == 'rule' %} {{ json_schema_rule }} {% endif %} diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 243b29281..3310a952e 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,15 +36,6 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if self.use_structured_output and prompt_stack.output_schema is not None: - if self.structured_output_strategy == "native": - return Message( - content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], - role=Message.ASSISTANT_ROLE, - usage=Message.Usage(input_tokens=100, output_tokens=100), - ) - elif self.structured_output_strategy == "tool": - self._add_structured_output_tool(prompt_stack) if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. @@ -58,41 +49,42 @@ def try_run(self, prompt_stack: PromptStack) -> Message: usage=Message.Usage(input_tokens=100, output_tokens=100), ) else: + if self.structured_output_strategy == "tool": + tool_action = ToolAction( + tag="mock-tag", + name="StructuredOutputTool", + path="provide_output", + input={"values": self.mock_structured_output}, + ) + else: + tool_action = ToolAction( + tag="mock-tag", + name="MockTool", + path="test", + input={"values": {"test": "test-value"}}, + ) + return Message( - content=[ - ActionCallMessageContent( - ActionArtifact( - ToolAction( - tag="mock-tag", - name="MockTool", - path="test", - input={"values": {"test": "test-value"}}, - ) - ) - ) - ], + content=[ActionCallMessageContent(ActionArtifact(tool_action))], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=100, output_tokens=100), ) else: - return Message( - content=[TextMessageContent(TextArtifact(output))], - role=Message.ASSISTANT_ROLE, - usage=Message.Usage(input_tokens=100, output_tokens=100), - ) - - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: - output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - - if self.use_structured_output and prompt_stack.output_schema is not None: - if self.structured_output_strategy == "native": - yield DeltaMessage( - content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), + if prompt_stack.output_schema is not None: + return Message( + content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), + ) + else: + return Message( + content=[TextMessageContent(TextArtifact(output))], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=100, output_tokens=100), ) - elif self.structured_output_strategy == "tool": - self._add_structured_output_tool(prompt_stack) + + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output if self.use_native_tools and prompt_stack.tools: # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer. 
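Taken together with the `PromptTask` and system template changes above, here is a hedged sketch of how the `rule` strategy looks from user code, using the new `Agent.output_schema`; the schema fields and prompt are illustrative assumptions:

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.structures import Agent

agent = Agent(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", structured_output_strategy="rule"),
    output_schema=schema.Schema({"answer": str}),  # forwarded to the Agent's PromptTask
)

agent.run("What is 6 * 7?")
# With the "rule" strategy the schema is rendered into the system prompt as a
# JsonSchemaRule, and the task wraps the reply in a JsonArtifact. Nothing forces the
# model to comply, which is why the docs recommend this strategy only as a last resort.
print(agent.output.value)
```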
@@ -103,15 +95,36 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: yield DeltaMessage(content=TextDeltaMessageContent(f"Answer: {output}")) yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100)) else: - yield DeltaMessage( - content=ActionCallDeltaMessageContent( - tag="mock-tag", - name="MockTool", - path="test", + if self.structured_output_strategy == "tool": + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + tag="mock-tag", + name="StructuredOutputTool", + path="provide_output", + ) ) - ) + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + partial_input=json.dumps({"values": self.mock_structured_output}) + ) + ) + else: + yield DeltaMessage( + content=ActionCallDeltaMessageContent( + tag="mock-tag", + name="MockTool", + path="test", + ) + ) + yield DeltaMessage( + content=ActionCallDeltaMessageContent(partial_input='{ "values": { "test": "test-value" } }') + ) + else: + if prompt_stack.output_schema is not None: yield DeltaMessage( - content=ActionCallDeltaMessageContent(partial_input='{ "values": { "test": "test-value" } }') + content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), ) - else: - yield DeltaMessage(content=TextDeltaMessageContent(output)) + else: + yield DeltaMessage(content=TextDeltaMessageContent(output)) diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index d9a4f4cb3..b2fd51d24 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,7 +51,6 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, - "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, @@ -108,7 +107,6 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, - "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 1df66b534..fa13480c1 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -26,7 +26,6 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, "structured_output_strategy": "tool", - "use_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index c63f8bdbc..a30cea001 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -37,7 +37,6 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", - "use_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 11a39ba4c..d5e05c9bd 100644 --- 
a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,8 +26,7 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, - "use_structured_output": True, - "structured_output_strategy": "native", + "structured_output_strategy": "rule", "extra_params": {}, }, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index fa8c07c8c..5adec7c6d 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,8 +18,7 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, - "use_structured_output": False, - "structured_output_strategy": "native", + "structured_output_strategy": "rule", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 1f53ae59f..910ae3240 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,7 +25,6 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, - "use_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index a77f9ab46..344d14d99 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -29,7 +29,6 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", - "use_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index d7e642b39..2dcb4bf02 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -8,29 +8,6 @@ class TestAmazonBedrockPromptDriver: - BEDROCK_STRUCTURED_OUTPUT_TOOL = { - "toolSpec": { - "description": "Used to provide the final response which ends this conversation.", - "inputSchema": { - "json": { - "$id": "http://json-schema.org/draft-07/schema#", - "$schema": "http://json-schema.org/draft-07/schema#", - "additionalProperties": False, - "properties": { - "values": { - "additionalProperties": False, - "properties": {"foo": {"type": "string"}}, - "required": ["foo"], - "type": "object", - }, - }, - "required": ["values"], - "type": "object", - }, - }, - "name": "StructuredOutputTool_provide_output", - }, - } BEDROCK_TOOLS = [ { "toolSpec": { @@ -384,13 +361,13 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, structured_output_strategy): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + 
structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -410,15 +387,10 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, **( { "toolConfig": { - "tools": [ - *self.BEDROCK_TOOLS, - *( - [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and driver.structured_output_strategy == "tool" - else [] - ), - ], - "toolChoice": {"any": {}} if use_structured_output else driver.tool_choice, + "tools": self.BEDROCK_TOOLS, + "toolChoice": {"any": {}} + if driver.structured_output_strategy == "tool" + else driver.tool_choice, } } if use_native_tools @@ -437,16 +409,16 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) def test_try_stream_run( - self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_structured_output + self, mock_converse_stream, prompt_stack, messages, use_native_tools, structured_output_strategy ): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -467,15 +439,10 @@ def test_try_stream_run( **( { "toolConfig": { - "tools": [ - *self.BEDROCK_TOOLS, - *( - [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and driver.structured_output_strategy == "tool" - else [] - ), - ], - "toolChoice": {"any": {}} if use_structured_output else driver.tool_choice, + "tools": self.BEDROCK_TOOLS, + "toolChoice": {"any": {}} + if driver.structured_output_strategy == "tool" + else driver.tool_choice, } } if use_native_tools @@ -506,6 +473,6 @@ def test_verify_structured_output_strategy(self): assert AmazonBedrockPromptDriver(model="foo", structured_output_strategy="tool") with pytest.raises( - ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output mode." + ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output strategy." 
): AmazonBedrockPromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index c7b0682c2..7b2d38398 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -138,3 +138,12 @@ def test_try_run_throws_on_empty_response(self, mock_client): # Then assert e.value.args[0] == "model response is empty" + + def test_verify_structured_output_strategy(self): + assert AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="foo", structured_output_strategy="rule") + + with pytest.raises( + ValueError, + match="AmazonSageMakerJumpstartPromptDriver does not support `native` structured output strategy.", + ): + AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 38b8c8bbb..fbdf1e55d 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -142,24 +142,6 @@ class TestAnthropicPromptDriver: }, ] - ANTHROPIC_STRUCTURED_OUTPUT_TOOL = { - "description": "Used to provide the final response which ends this conversation.", - "input_schema": { - "additionalProperties": False, - "properties": { - "values": { - "additionalProperties": False, - "properties": {"foo": {"type": "string"}}, - "required": ["foo"], - "type": "object", - }, - }, - "required": ["values"], - "type": "object", - }, - "name": "StructuredOutputTool_provide_output", - } - @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") @@ -370,14 +352,14 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, structured_output_strategy): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -395,17 +377,8 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, **{ - "tools": [ - *self.ANTHROPIC_TOOLS, - *( - [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and driver.structured_output_strategy == "tool" - else [] - ), - ] - if use_native_tools - else {}, - "tool_choice": {"type": "any"} if use_structured_output else driver.tool_choice, + "tools": self.ANTHROPIC_TOOLS if use_native_tools else {}, + "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -422,15 +395,17 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, us assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", 
[True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools, use_structured_output): + @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"]) + def test_try_stream_run( + self, mock_stream_client, prompt_stack, messages, use_native_tools, structured_output_strategy + ): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, + structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -450,17 +425,8 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, **{ - "tools": [ - *self.ANTHROPIC_TOOLS, - *( - [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and driver.structured_output_strategy == "tool" - else [] - ), - ] - if use_native_tools - else {}, - "tool_choice": {"type": "any"} if use_structured_output else driver.tool_choice, + "tools": self.ANTHROPIC_TOOLS if use_native_tools else {}, + "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -492,5 +458,7 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na def test_verify_structured_output_strategy(self): assert AnthropicPromptDriver(model="foo", structured_output_strategy="tool") - with pytest.raises(ValueError, match="AnthropicPromptDriver does not support `native` structured output mode."): + with pytest.raises( + ValueError, match="AnthropicPromptDriver does not support `native` structured output strategy." 
+ ): AnthropicPromptDriver(model="foo", structured_output_strategy="native") diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index d97c16ba3..8f0da735a 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,7 +67,6 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_run( self, @@ -75,7 +74,6 @@ def test_try_run( prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given @@ -84,7 +82,6 @@ def test_try_run( azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -99,17 +96,8 @@ def test_try_run( user=driver.user, messages=messages, **{ - "tools": [ - *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ], - "tool_choice": "required" - if use_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tools": self.OPENAI_TOOLS, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -123,7 +111,7 @@ def test_try_run( }, } } - if use_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) @@ -136,7 +124,6 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) - @pytest.mark.parametrize("use_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_stream_run( self, @@ -144,7 +131,6 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, - use_structured_output, structured_output_strategy, ): # Given @@ -154,7 +140,6 @@ def test_try_stream_run( model="gpt-4", stream=True, use_native_tools=use_native_tools, - use_structured_output=use_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -171,17 +156,8 @@ def test_try_stream_run( stream=True, messages=messages, **{ - "tools": [ - *self.OPENAI_TOOLS, - *( - [self.OPENAI_STRUCTURED_OUTPUT_TOOL] - if use_structured_output and structured_output_strategy == "tool" - else [] - ), - ], - "tool_choice": "required" - if use_structured_output and structured_output_strategy == "tool" - else driver.tool_choice, + "tools": self.OPENAI_TOOLS, + "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, } if use_native_tools else {}, @@ -195,7 +171,7 @@ def test_try_stream_run( }, } } - if use_structured_output and structured_output_strategy == "native" + if structured_output_strategy == "native" else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 985cc3d31..58720bbc5 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ 
b/tests/unit/drivers/prompt/test_base_prompt_driver.py
@@ -1,5 +1,3 @@
-import pytest
-
 from griptape.artifacts import ErrorArtifact, TextArtifact
 from griptape.common import Message, PromptStack
 from griptape.events import FinishPromptEvent, StartPromptEvent
@@ -67,24 +65,3 @@ def test_run_with_tools_and_stream(self, mock_config):
         output = pipeline.run().output_task.output
         assert isinstance(output, TextArtifact)
         assert output.value == "mock output"
-
-    def test__add_structured_output_tool(self):
-        from schema import Schema
-
-        from griptape.tools.structured_output.tool import StructuredOutputTool
-
-        mock_prompt_driver = MockPromptDriver()
-
-        prompt_stack = PromptStack()
-
-        with pytest.raises(ValueError, match="PromptStack must have an output schema to use structured output."):
-            mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack)
-
-        prompt_stack.output_schema = Schema({"foo": str})
-
-        mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack)
-        # Ensure it doesn't get added twice
-        mock_prompt_driver._add_structured_output_tool_if_absent(prompt_stack)
-        assert len(prompt_stack.tools) == 1
-        assert isinstance(prompt_stack.tools[0], StructuredOutputTool)
-        assert prompt_stack.tools[0].output_schema is prompt_stack.output_schema
diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py
index 858aa5bee..8b51940c8 100644
--- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py
@@ -21,28 +21,6 @@ class TestCoherePromptDriver:
         "required": ["foo"],
         "type": "object",
     }
-    COHERE_STRUCTURED_OUTPUT_TOOL = {
-        "function": {
-            "description": "Used to provide the final response which ends this conversation.",
-            "name": "StructuredOutputTool_provide_output",
-            "parameters": {
-                "$id": "Parameters Schema",
-                "$schema": "http://json-schema.org/draft-07/schema#",
-                "additionalProperties": False,
-                "properties": {
-                    "values": {
-                        "additionalProperties": False,
-                        "properties": {"foo": {"type": "string"}},
-                        "required": ["foo"],
-                        "type": "object",
-                    },
-                },
-                "required": ["values"],
-                "type": "object",
-            },
-        },
-        "type": "function",
-    }
     COHERE_TOOLS = [
         {
             "function": {
@@ -338,7 +316,6 @@ def test_init(self):
         assert CoherePromptDriver(model="command", api_key="foobar")
 
     @pytest.mark.parametrize("use_native_tools", [True, False])
-    @pytest.mark.parametrize("use_structured_output", [True, False])
     @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"])
     def test_try_run(
         self,
@@ -346,7 +323,6 @@ def test_try_run(
         self,
         prompt_stack,
         messages,
         use_native_tools,
-        use_structured_output,
         structured_output_strategy,
     ):
         # Given
@@ -354,7 +330,6 @@
             model="command",
             api_key="api-key",
             use_native_tools=use_native_tools,
-            use_structured_output=use_structured_output,
             structured_output_strategy=structured_output_strategy,
             extra_params={"foo": "bar"},
         )
@@ -367,25 +342,14 @@
             model="command",
             messages=messages,
             max_tokens=None,
-            **{
-                "tools": [
-                    *self.COHERE_TOOLS,
-                    *(
-                        [self.COHERE_STRUCTURED_OUTPUT_TOOL]
-                        if use_structured_output and structured_output_strategy == "tool"
-                        else []
-                    ),
-                ]
-            }
-            if use_native_tools
-            else {},
+            **{"tools": self.COHERE_TOOLS} if use_native_tools else {},
             **{
                 "response_format": {
                     "type": "json_object",
                     "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA,
                 }
             }
-            if use_structured_output and structured_output_strategy == "native"
+            if structured_output_strategy == "native"
             else {},
             stop_sequences=[],
             temperature=0.1,
@@ -406,7 +370,6 @@ def test_try_run(
         assert message.usage.output_tokens == 10
 
     @pytest.mark.parametrize("use_native_tools", [True, False])
-    @pytest.mark.parametrize("use_structured_output", [True, False])
     @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"])
     def test_try_stream_run(
         self,
@@ -414,7 +377,6 @@ def test_try_stream_run(
         self,
         prompt_stack,
         messages,
         use_native_tools,
-        use_structured_output,
         structured_output_strategy,
     ):
         # Given
@@ -423,7 +385,6 @@
             model="command",
             api_key="api-key",
             stream=True,
             use_native_tools=use_native_tools,
-            use_structured_output=use_structured_output,
             structured_output_strategy=structured_output_strategy,
             extra_params={"foo": "bar"},
         )
@@ -437,25 +398,14 @@
             model="command",
             messages=messages,
             max_tokens=None,
-            **{
-                "tools": [
-                    *self.COHERE_TOOLS,
-                    *(
-                        [self.COHERE_STRUCTURED_OUTPUT_TOOL]
-                        if use_structured_output and structured_output_strategy == "tool"
-                        else []
-                    ),
-                ]
-            }
-            if use_native_tools
-            else {},
+            **{"tools": self.COHERE_TOOLS} if use_native_tools else {},
             **{
                 "response_format": {
                     "type": "json_object",
                     "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA,
                 }
             }
-            if use_structured_output and structured_output_strategy == "native"
+            if structured_output_strategy == "native"
             else {},
             stop_sequences=[],
             temperature=0.1,
diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py
index 53c33735e..aacc207b9 100644
--- a/tests/unit/drivers/prompt/test_google_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py
@@ -14,15 +14,6 @@ class TestGooglePromptDriver:
-    GOOGLE_STRUCTURED_OUTPUT_TOOL = {
-        "description": "Used to provide the final response which ends this conversation.",
-        "name": "StructuredOutputTool_provide_output",
-        "parameters": {
-            "properties": {"foo": {"type": "STRING"}},
-            "required": ["foo"],
-            "type": "OBJECT",
-        },
-    }
     GOOGLE_TOOLS = [
         {
             "name": "MockTool_test",
@@ -177,8 +168,8 @@ def test_init(self):
         assert driver
 
     @pytest.mark.parametrize("use_native_tools", [True, False])
-    @pytest.mark.parametrize("use_structured_output", [True, False])
-    def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools, use_structured_output):
+    @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"])
+    def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools, structured_output_strategy):
         # Given
         driver = GooglePromptDriver(
             model="gemini-pro",
@@ -186,8 +177,7 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native
             top_p=0.5,
             top_k=50,
             use_native_tools=use_native_tools,
-            use_structured_output=use_structured_output,
-            structured_output_strategy="tool",
+            structured_output_strategy=structured_output_strategy,
             extra_params={"max_output_tokens": 10},
         )
@@ -209,13 +199,10 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native
         )
         if use_native_tools:
             tool_declarations = call_args.kwargs["tools"]
-            tools = [
-                *self.GOOGLE_TOOLS,
-                *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_structured_output else []),
-            ]
+            tools = self.GOOGLE_TOOLS
             assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools
-        if use_structured_output:
+        if driver.structured_output_strategy == "tool":
             assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}}
 
         assert isinstance(message.value[0], TextArtifact)
@@ -229,9 +216,9 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native
         assert message.usage.output_tokens == 10
 
     @pytest.mark.parametrize("use_native_tools", [True, False])
-    @pytest.mark.parametrize("use_structured_output", [True, False])
+    @pytest.mark.parametrize("structured_output_strategy", ["tool", "rule", "foo"])
     def test_try_stream(
-        self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_structured_output
+        self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, structured_output_strategy
     ):
         # Given
         driver = GooglePromptDriver(
@@ -241,7 +228,7 @@ def test_try_stream(
             top_p=0.5,
             top_k=50,
             use_native_tools=use_native_tools,
-            use_structured_output=use_structured_output,
+            structured_output_strategy=structured_output_strategy,
             extra_params={"max_output_tokens": 10},
         )
@@ -265,13 +252,10 @@ def test_try_stream(
         )
         if use_native_tools:
             tool_declarations = call_args.kwargs["tools"]
-            tools = [
-                *self.GOOGLE_TOOLS,
-                *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_structured_output else []),
-            ]
+            tools = self.GOOGLE_TOOLS
             assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools
-        if use_structured_output:
+        if driver.structured_output_strategy == "tool":
             assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}}
         assert isinstance(event.content, TextDeltaMessageContent)
         assert event.content.text == "model-output"
@@ -291,5 +275,7 @@ def test_try_stream(
     def test_verify_structured_output_strategy(self):
         assert GooglePromptDriver(model="foo", structured_output_strategy="tool")
 
-        with pytest.raises(ValueError, match="GooglePromptDriver does not support `native` structured output mode."):
+        with pytest.raises(
+            ValueError, match="GooglePromptDriver does not support `native` structured output strategy."
+        ):
             GooglePromptDriver(model="foo", structured_output_strategy="native")
diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py
index 763a4f7b1..b757dbcea 100644
--- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py
@@ -54,14 +54,14 @@ def mock_autotokenizer(self, mocker):
     def test_init(self):
         assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2")
 
-    @pytest.mark.parametrize("use_structured_output", [True, False])
-    def test_try_run(self, prompt_stack, mock_client, use_structured_output):
+    @pytest.mark.parametrize("structured_output_strategy", ["native", "rule", "foo"])
+    def test_try_run(self, prompt_stack, mock_client, structured_output_strategy):
         # Given
         driver = HuggingFaceHubPromptDriver(
             api_token="api-token",
             model="repo-id",
-            use_structured_output=use_structured_output,
             extra_params={"foo": "bar"},
+            structured_output_strategy=structured_output_strategy,
         )
 
         # When
@@ -73,23 +73,27 @@ def test_try_run(self, prompt_stack, mock_client, use_structured_output):
             return_full_text=False,
             max_new_tokens=250,
             foo="bar",
-            **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}}
-            if use_structured_output
-            else {},
+            **(
+                {
+                    "grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA},
+                }
+                if structured_output_strategy == "native"
+                else {}
+            ),
         )
         assert message.value == "model-output"
         assert message.usage.input_tokens == 3
         assert message.usage.output_tokens == 3
 
-    @pytest.mark.parametrize("use_structured_output", [True, False])
-    def test_try_stream(self, prompt_stack, mock_client_stream, use_structured_output):
+    @pytest.mark.parametrize("structured_output_strategy", ["native", "rule", "foo"])
+    def test_try_stream(self, prompt_stack, mock_client_stream, structured_output_strategy):
         # Given
         driver = HuggingFaceHubPromptDriver(
             api_token="api-token",
             model="repo-id",
             stream=True,
-            use_structured_output=use_structured_output,
             extra_params={"foo": "bar"},
+            structured_output_strategy=structured_output_strategy,
         )
 
         # When
@@ -102,9 +106,13 @@ def test_try_stream(self, prompt_stack, mock_client_stream, use_structured_outpu
             return_full_text=False,
             max_new_tokens=250,
             foo="bar",
-            **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}}
-            if use_structured_output
-            else {},
+            **(
+                {
+                    "grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA},
+                }
+                if structured_output_strategy == "native"
+                else {}
+            ),
             stream=True,
         )
         assert isinstance(event.content, TextDeltaMessageContent)
@@ -118,6 +126,6 @@ def test_verify_structured_output_strategy(self):
         assert HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="native")
 
         with pytest.raises(
-            ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output mode."
+            ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output strategy."
         ):
             HuggingFaceHubPromptDriver(model="foo", api_token="bar", structured_output_strategy="tool")
diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
index af52ca4e9..e03604aaf 100644
--- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py
@@ -42,10 +42,15 @@ def messages(self):
     def test_init(self, mock_pipeline):
         assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42, pipeline=mock_pipeline)
 
-    def test_try_run(self, prompt_stack, messages, mock_pipeline):
+    @pytest.mark.parametrize("structured_output_strategy", ["rule", "foo"])
+    def test_try_run(self, prompt_stack, messages, mock_pipeline, structured_output_strategy):
         # Given
         driver = HuggingFacePipelinePromptDriver(
-            model="foo", max_tokens=42, extra_params={"foo": "bar"}, pipeline=mock_pipeline
+            model="foo",
+            max_tokens=42,
+            extra_params={"foo": "bar"},
+            pipeline=mock_pipeline,
+            structured_output_strategy=structured_output_strategy,
         )
 
         # When
@@ -57,9 +62,12 @@ def test_try_run(self, prompt_stack, messages, mock_pipeline):
         assert message.usage.input_tokens == 3
         assert message.usage.output_tokens == 3
 
-    def test_try_stream(self, prompt_stack, mock_pipeline):
+    @pytest.mark.parametrize("structured_output_strategy", ["rule", "foo"])
+    def test_try_stream(self, prompt_stack, mock_pipeline, structured_output_strategy):
         # Given
-        driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42, pipeline=mock_pipeline)
+        driver = HuggingFacePipelinePromptDriver(
+            model="foo", max_tokens=42, pipeline=mock_pipeline, structured_output_strategy=structured_output_strategy
+        )
 
         # When
         with pytest.raises(Exception) as e:
@@ -101,3 +109,11 @@ def test_prompt_stack_to_string(self, prompt_stack, mock_pipeline):
 
         # Then
         assert result == "model-output"
+
+    def test_verify_structured_output_strategy(self):
+        assert HuggingFacePipelinePromptDriver(model="foo", structured_output_strategy="rule")
+
+        with pytest.raises(
+            ValueError, match="HuggingFacePipelinePromptDriver does not support `native` structured output strategy."
+        ):
+            HuggingFacePipelinePromptDriver(model="foo", structured_output_strategy="native")
diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py
index d638e84e2..02f284b76 100644
--- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py
@@ -16,19 +16,6 @@ class TestOllamaPromptDriver:
         "required": ["foo"],
         "type": "object",
     }
-    OLLAMA_STRUCTURED_OUTPUT_TOOL = {
-        "function": {
-            "description": "Used to provide the final response which ends this conversation.",
-            "name": "StructuredOutputTool_provide_output",
-            "parameters": {
-                "additionalProperties": False,
-                "properties": {"foo": {"type": "string"}},
-                "required": ["foo"],
-                "type": "object",
-            },
-        },
-        "type": "function",
-    }
     OLLAMA_TOOLS = [
         {
             "function": {
@@ -232,22 +219,19 @@ def test_init(self):
         assert OllamaPromptDriver(model="llama")
 
     @pytest.mark.parametrize("use_native_tools", [True, False])
-    @pytest.mark.parametrize("use_structured_output", [True, False])
-    @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"])
+    @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"])
     def test_try_run(
         self,
         mock_client,
         prompt_stack,
         messages,
         use_native_tools,
-        use_structured_output,
         structured_output_strategy,
     ):
         # Given
         driver = OllamaPromptDriver(
             model="llama",
             use_native_tools=use_native_tools,
-            use_structured_output=use_structured_output,
             structured_output_strategy=structured_output_strategy,
             extra_params={"foo": "bar"},
         )
@@ -265,20 +249,11 @@ def test_try_run(
                 "num_predict": driver.max_tokens,
             },
             **{
-                "tools": [
-                    *self.OLLAMA_TOOLS,
-                    *(
-                        [self.OLLAMA_STRUCTURED_OUTPUT_TOOL]
-                        if use_structured_output and structured_output_strategy == "tool"
-                        else []
-                    ),
-                ]
+                "tools": self.OLLAMA_TOOLS,
             }
             if use_native_tools
             else {},
-            **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA}
-            if use_structured_output and structured_output_strategy == "native"
-            else {},
+            **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {},
             foo="bar",
         )
         assert isinstance(message.value[0], TextArtifact)
@@ -290,15 +265,13 @@ def test_try_run(
         assert message.value[1].value.input == {"foo": "bar"}
 
     @pytest.mark.parametrize("use_native_tools", [True, False])
-    @pytest.mark.parametrize("use_structured_output", [True, False])
-    @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"])
+    @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"])
     def test_try_stream_run(
         self,
         mock_stream_client,
         prompt_stack,
         messages,
         use_native_tools,
-        use_structured_output,
         structured_output_strategy,
     ):
         # Given
@@ -306,7 +279,6 @@ def test_try_stream_run(
             model="llama",
             stream=True,
             use_native_tools=use_native_tools,
-            use_structured_output=use_structured_output,
             structured_output_strategy=structured_output_strategy,
             extra_params={"foo": "bar"},
         )
@@ -319,9 +291,7 @@ def test_try_stream_run(
             messages=messages,
             model=driver.model,
             options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens},
-            **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA}
-            if use_structured_output and structured_output_strategy == "native"
-            else {},
+            **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {},
             stream=True,
             foo="bar",
         )
diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
index eff9fda66..496560529 100644
--- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
@@ -20,28 +20,6 @@ class TestOpenAiChatPromptDriverFixtureMixin:
         "required": ["foo"],
         "type": "object",
     }
-    OPENAI_STRUCTURED_OUTPUT_TOOL = {
-        "function": {
-            "description": "Used to provide the final response which ends this conversation.",
-            "name": "StructuredOutputTool_provide_output",
-            "parameters": {
-                "$id": "Parameters Schema",
-                "$schema": "http://json-schema.org/draft-07/schema#",
-                "additionalProperties": False,
-                "properties": {
-                    "values": {
-                        "additionalProperties": False,
-                        "properties": {"foo": {"type": "string"}},
-                        "required": ["foo"],
-                        "type": "object",
-                    },
-                },
-                "required": ["values"],
-                "type": "object",
-            },
-        },
-        "type": "function",
-    }
     OPENAI_TOOLS = [
         {
             "function": {
@@ -371,22 +349,19 @@ def test_init(self):
         assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL)
 
     @pytest.mark.parametrize("use_native_tools", [True, False])
-    @pytest.mark.parametrize("use_structured_output", [True, False])
-    @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"])
+    @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"])
     def test_try_run(
         self,
         mock_chat_completion_create,
         prompt_stack,
         messages,
         use_native_tools,
-        use_structured_output,
         structured_output_strategy,
     ):
         # Given
         driver = OpenAiChatPromptDriver(
             model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
             use_native_tools=use_native_tools,
-            use_structured_output=use_structured_output,
             structured_output_strategy=structured_output_strategy,
             extra_params={"foo": "bar"},
         )
@@ -402,17 +377,8 @@ def test_try_run(
             messages=messages,
             seed=driver.seed,
             **{
-                "tools": [
-                    *self.OPENAI_TOOLS,
-                    *(
-                        [self.OPENAI_STRUCTURED_OUTPUT_TOOL]
-                        if use_structured_output and structured_output_strategy == "tool"
-                        else []
-                    ),
-                ],
-                "tool_choice": "required"
-                if use_structured_output and structured_output_strategy == "tool"
-                else driver.tool_choice,
+                "tools": self.OPENAI_TOOLS,
+                "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice,
                 "parallel_tool_calls": driver.parallel_tool_calls,
             }
             if use_native_tools
@@ -427,7 +393,7 @@
                     },
                 }
             }
-            if use_structured_output and structured_output_strategy == "native"
+            if prompt_stack.output_schema is not None and structured_output_strategy == "native"
             else {},
             foo="bar",
         )
@@ -509,15 +475,13 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create,
         assert message.usage.output_tokens == 10
 
     @pytest.mark.parametrize("use_native_tools", [True, False])
-    @pytest.mark.parametrize("use_structured_output", [True, False])
-    @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"])
+    @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"])
    def test_try_stream_run(
         self,
         mock_chat_completion_stream_create,
         prompt_stack,
         messages,
         use_native_tools,
-        use_structured_output,
         structured_output_strategy,
     ):
         # Given
@@ -525,7 +489,6 @@ def test_try_stream_run(
             model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
             stream=True,
             use_native_tools=use_native_tools,
-            use_structured_output=use_structured_output,
             structured_output_strategy=structured_output_strategy,
             extra_params={"foo": "bar"},
         )
@@ -544,17 +507,8 @@ def test_try_stream_run(
             seed=driver.seed,
             stream_options={"include_usage": True},
             **{
-                "tools": [
-                    *self.OPENAI_TOOLS,
-                    *(
-                        [self.OPENAI_STRUCTURED_OUTPUT_TOOL]
-                        if use_structured_output and structured_output_strategy == "tool"
-                        else []
-                    ),
-                ],
-                "tool_choice": "required"
-                if use_structured_output and structured_output_strategy == "tool"
-                else driver.tool_choice,
+                "tools": self.OPENAI_TOOLS,
+                "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice,
                 "parallel_tool_calls": driver.parallel_tool_calls,
             }
             if use_native_tools
@@ -569,7 +523,7 @@
                     },
                 }
             }
-            if use_structured_output and structured_output_strategy == "native"
+            if structured_output_strategy == "native"
             else {},
             foo="bar",
         )
@@ -596,11 +550,11 @@ def test_try_stream_run(
     def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages):
         # Given
+        prompt_stack.output_schema = None
         driver = OpenAiChatPromptDriver(
             model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
             max_tokens=1,
             use_native_tools=False,
-            use_structured_output=False,
         )
 
         # When
@@ -630,12 +584,12 @@ def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completio
         assert e.value.args[0] == "Completion with more than one choice is not supported yet."
 
     def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages):
+        prompt_stack.output_schema = None
         driver = OpenAiChatPromptDriver(
             model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
             tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]),
             max_tokens=1,
             use_native_tools=False,
-            use_structured_output=False,
         )
 
         # When
diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py
index 809d174b5..442f654d5 100644
--- a/tests/unit/structures/test_agent.py
+++ b/tests/unit/structures/test_agent.py
@@ -1,6 +1,7 @@
 from unittest.mock import Mock
 
 import pytest
+import schema
 
 from griptape.memory import TaskMemory
 from griptape.memory.structure import ConversationMemory
@@ -316,3 +317,14 @@ def test_field_hierarchy(self):
 
         assert isinstance(agent.tasks[0], PromptTask)
         assert agent.tasks[0].prompt_driver.stream is True
+
+    def test_output_schema(self):
+        agent = Agent()
+
+        assert isinstance(agent.tasks[0], PromptTask)
+        assert agent.tasks[0].output_schema is None
+
+        agent = Agent(output_schema=schema.Schema({"foo": str}))
+
+        assert isinstance(agent.tasks[0], PromptTask)
+        assert agent.tasks[0].output_schema is agent.output_schema
diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py
index 807e78f0b..da277e81e 100644
--- a/tests/unit/structures/test_structure.py
+++ b/tests/unit/structures/test_structure.py
@@ -83,8 +83,7 @@ def test_to_dict(self):
                     "temperature": 0.1,
                     "type": "MockPromptDriver",
                     "use_native_tools": False,
-                    "use_structured_output": False,
-                    "structured_output_strategy": "native",
+                    "structured_output_strategy": "rule",
                 },
             }
         ],
diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py
index 60a10f1a4..2cd102bf8 100644
--- a/tests/unit/tasks/test_prompt_task.py
+++ b/tests/unit/tasks/test_prompt_task.py
@@ -1,5 +1,3 @@
-import warnings
-
 from griptape.artifacts.image_artifact import ImageArtifact
 from griptape.artifacts.json_artifact import JsonArtifact
 from griptape.artifacts.list_artifact import ListArtifact
@@ -183,8 +181,8 @@ def test_prompt_stack_native_schema(self):
         task = PromptTask(
             input="foo",
             prompt_driver=MockPromptDriver(
-                use_structured_output=True,
                 mock_structured_output={"baz": "foo"},
+                structured_output_strategy="native",
             ),
             output_schema=output_schema,
         )
@@ -197,17 +195,33 @@ def test_prompt_stack_native_schema(self):
         assert task.prompt_stack.messages[0].is_user()
         assert "foo" in task.prompt_stack.messages[0].to_text()
 
-        # Ensure no warnings were raised
-        with warnings.catch_warnings():
-            warnings.simplefilter("error")
-            assert task.prompt_stack
+    def test_prompt_stack_tool_schema(self):
+        from schema import Schema
 
-    def test_prompt_stack_empty_native_schema(self):
+        output_schema = Schema({"baz": str})
         task = PromptTask(
             input="foo",
             prompt_driver=MockPromptDriver(
-                use_structured_output=True,
+                mock_structured_output={"baz": "foo"},
+                structured_output_strategy="tool",
+                use_native_tools=True,
             ),
+            output_schema=output_schema,
+        )
+        output = task.run()
+
+        assert isinstance(output, JsonArtifact)
+        assert output.value == {"baz": "foo"}
+
+        assert task.prompt_stack.output_schema is output_schema
+        assert task.prompt_stack.messages[0].is_system()
+        assert task.prompt_stack.messages[1].is_user()
+        assert "foo" in task.prompt_stack.messages[1].to_text()
+
+    def test_prompt_stack_empty_native_schema(self):
+        task = PromptTask(
+            input="foo",
+            prompt_driver=MockPromptDriver(),
             rules=[JsonSchemaRule({"foo": {}})],
         )
diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py
index 00bbadc45..5c7f6b394 100644
--- a/tests/unit/tasks/test_tool_task.py
+++ b/tests/unit/tasks/test_tool_task.py
@@ -257,8 +257,7 @@ def test_to_dict(self):
                 "stream": False,
                 "temperature": 0.1,
                 "type": "MockPromptDriver",
-                "structured_output_strategy": "native",
-                "use_structured_output": False,
+                "structured_output_strategy": "rule",
                 "use_native_tools": False,
             },
             "tool": {
diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py
index 70c59e1f8..a5e95f4d1 100644
--- a/tests/unit/tasks/test_toolkit_task.py
+++ b/tests/unit/tasks/test_toolkit_task.py
@@ -399,8 +399,7 @@ def test_to_dict(self):
                 "temperature": 0.1,
                 "type": "MockPromptDriver",
                 "use_native_tools": False,
-                "use_structured_output": False,
-                "structured_output_strategy": "native",
+                "structured_output_strategy": "rule",
             },
             "tools": [
                 {
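Note: the updated tests above exercise the reworked surface that this patch leaves behind: the output schema lives on the task (PromptTask.output_schema, or Agent.output_schema, which is forwarded to the Agent's Prompt Task), while the driver only selects a structured_output_strategy of "native", "tool", or "rule". The following is a minimal usage sketch mirroring test_prompt_stack_tool_schema and test_output_schema; it reuses the repo's MockPromptDriver test fixture (import path assumed from the test suite), so a real Prompt Driver would be substituted in application code.

import schema

from griptape.structures import Agent
from griptape.tasks import PromptTask
from tests.mocks.mock_prompt_driver import MockPromptDriver  # test fixture; swap in a real driver outside the test suite

# The schema is set on the Task; the Driver only chooses how the output is structured.
task = PromptTask(
    input="foo",
    prompt_driver=MockPromptDriver(
        mock_structured_output={"baz": "foo"},
        structured_output_strategy="tool",
        use_native_tools=True,
    ),
    output_schema=schema.Schema({"baz": str}),
)
output = task.run()  # per the test above, a JsonArtifact whose value is {"baz": "foo"}

# Agent.output_schema is forwarded to the Agent's Prompt Task.
agent = Agent(output_schema=schema.Schema({"foo": str}))
assert agent.tasks[0].output_schema is agent.output_schema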