diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md
index a6694726b..e45d28d71 100644
--- a/docs/griptape-framework/drivers/prompt-drivers.md
+++ b/docs/griptape-framework/drivers/prompt-drivers.md
@@ -41,8 +41,8 @@ The easiest way to get started with structured output is by using a [PromptTask]
 You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of:
 
 - `native`: The Driver will use the LLM's structured output functionality provided by the API.
-- `tool`: The Task will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and the Driver will try to force the LLM to use the Tool.
-- `rule`: The Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort.
+- `tool`: The Driver will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and will try to force the LLM to use the Tool.
+- `rule`: The Driver will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort.
 
 ```python
 --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
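As a companion to the doc change above, here is a minimal usage sketch of the strategies at the Task level (the driver class, model name, and schema are illustrative assumptions; the bundled `prompt_drivers_structured_output.py` example remains the canonical reference):

```python
from schema import Schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

# Illustrative only: any prompt driver that supports the chosen strategy works.
pipeline = Pipeline(
    tasks=[
        PromptTask(
            prompt_driver=OpenAiChatPromptDriver(
                model="gpt-4o",  # assumed model
                structured_output_strategy="native",  # or "tool" / "rule"
            ),
            output_schema=Schema({"answer": str}),
        )
    ]
)

# The parsed output lands on the output task as an artifact whose value matches the schema.
print(pipeline.run("What is the capital of France?").output_task.output.value)
```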
diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py
index c5ffb7259..a6d769021 100644
--- a/griptape/drivers/prompt/base_prompt_driver.py
+++ b/griptape/drivers/prompt/base_prompt_driver.py
@@ -5,7 +5,7 @@
 
 from attrs import Factory, define, field
 
-from griptape.artifacts.base_artifact import BaseArtifact
+from griptape.artifacts import BaseArtifact, TextArtifact
 from griptape.common import (
     ActionCallDeltaMessageContent,
     ActionCallMessageContent,
@@ -26,6 +26,7 @@
 )
 from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
 from griptape.mixins.serializable_mixin import SerializableMixin
+from griptape.rules.json_schema_rule import JsonSchemaRule
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
@@ -64,6 +65,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
     extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
 
     def before_run(self, prompt_stack: PromptStack) -> None:
+        self._init_structured_output(prompt_stack)
         EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))
 
     def after_run(self, result: Message) -> None:
@@ -127,6 +129,34 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
 
     @abstractmethod
     def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...
 
+    def _init_structured_output(self, prompt_stack: PromptStack) -> None:
+        from griptape.tools import StructuredOutputTool
+
+        if (output_schema := prompt_stack.output_schema) is not None:
+            if self.structured_output_strategy == "tool":
+                structured_output_tool = StructuredOutputTool(output_schema=output_schema)
+                if structured_output_tool not in prompt_stack.tools:
+                    prompt_stack.tools.append(structured_output_tool)
+            elif self.structured_output_strategy == "rule":
+                output_artifact = TextArtifact(JsonSchemaRule(output_schema.json_schema("Output Schema")).to_text())
+                system_messages = prompt_stack.system_messages
+                if system_messages:
+                    last_system_message = prompt_stack.system_messages[-1]
+                    last_system_message.content.extend(
+                        [
+                            TextMessageContent(TextArtifact("\n\n")),
+                            TextMessageContent(output_artifact),
+                        ]
+                    )
+                else:
+                    prompt_stack.messages.insert(
+                        0,
+                        Message(
+                            content=[TextMessageContent(output_artifact)],
+                            role=Message.SYSTEM_ROLE,
+                        ),
+                    )
+
     def __process_run(self, prompt_stack: PromptStack) -> Message:
         return self.try_run(prompt_stack)
diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py
index 15c0f7457..b70b1eac7 100644
--- a/griptape/tasks/prompt_task.py
+++ b/griptape/tasks/prompt_task.py
@@ -15,7 +15,6 @@
 from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin
 from griptape.mixins.rule_mixin import RuleMixin
 from griptape.rules import Ruleset
-from griptape.rules.json_schema_rule import JsonSchemaRule
 from griptape.tasks import ActionsSubtask, BaseTask
 from griptape.utils import J2
 
@@ -92,16 +91,9 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask],
 
     @property
     def prompt_stack(self) -> PromptStack:
-        from griptape.tools.structured_output.tool import StructuredOutputTool
-
-        stack = PromptStack(tools=self.tools)
+        stack = PromptStack(tools=self.tools, output_schema=self.output_schema)
         memory = self.structure.conversation_memory if self.structure is not None else None
 
-        if self.output_schema is not None:
-            stack.output_schema = self.output_schema
-            if self.prompt_driver.structured_output_strategy == "tool":
-                stack.tools.append(StructuredOutputTool(output_schema=stack.output_schema))
-
         system_template = self.generate_system_template(self)
         if system_template:
             stack.add_system_message(system_template)
@@ -227,10 +219,6 @@ def default_generate_system_template(self, _: PromptTask) -> str:
             actions_schema=utils.minify_json(json.dumps(schema)),
             meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
             use_native_tools=self.prompt_driver.use_native_tools,
-            structured_output_strategy=self.prompt_driver.structured_output_strategy,
-            json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output"))
-            if self.output_schema is not None
-            else None,
             stop_sequence=self.response_stop_sequence,
         )
 
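One note on the `rule` branch of `_init_structured_output` above: it reuses `JsonSchemaRule` to render the schema text that gets appended to the last system message (or inserted as a new one). A small sketch of that rendering, assuming a `{"baz": str}` schema; the exact rule wording may vary by version:

```python
from schema import Schema

from griptape.rules.json_schema_rule import JsonSchemaRule

output_schema = Schema({"baz": str})

# Same expression the Driver uses to build the injected TextArtifact.
rule_text = JsonSchemaRule(output_schema.json_schema("Output Schema")).to_text()
print(rule_text)  # instructions plus the serialized JSON Schema for {"baz": str}
```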
diff --git a/griptape/templates/tasks/prompt_task/system.j2 b/griptape/templates/tasks/prompt_task/system.j2
index 8e89e13c7..b262e7c72 100644
--- a/griptape/templates/tasks/prompt_task/system.j2
+++ b/griptape/templates/tasks/prompt_task/system.j2
@@ -26,7 +26,3 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. NEVER
 
 {{ rulesets }}
 {% endif %}
-{% if json_schema_rule and structured_output_strategy == 'rule' %}
-
-{{ json_schema_rule }}
-{% endif %}
diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py
index 3310a952e..1b481067b 100644
--- a/tests/mocks/mock_prompt_driver.py
+++ b/tests/mocks/mock_prompt_driver.py
@@ -36,7 +36,6 @@ class MockPromptDriver(BasePromptDriver):
 
     def try_run(self, prompt_stack: PromptStack) -> Message:
         output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output
-
         if self.use_native_tools and prompt_stack.tools:
             # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer.
             action_messages = [
@@ -85,7 +84,6 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
 
     def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
         output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output
-
         if self.use_native_tools and prompt_stack.tools:
             # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer.
             action_messages = [
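The tool-strategy test below runs the Driver twice against the same `PromptStack`; the membership check in `_init_structured_output` is what keeps the injected tool from being appended a second time. A standalone sketch of that check (assuming attrs-based equality on `StructuredOutputTool`):

```python
from schema import Schema

from griptape.common import PromptStack
from griptape.tools import StructuredOutputTool

output_schema = Schema({"baz": str})
prompt_stack = PromptStack(messages=[], output_schema=output_schema)

for _ in range(2):  # simulate two driver runs over the same prompt stack
    tool = StructuredOutputTool(output_schema=output_schema)
    if tool not in prompt_stack.tools:  # same membership check as _init_structured_output
        prompt_stack.tools.append(tool)

assert len(prompt_stack.tools) == 1
```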
diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py
index 58720bbc5..3ffcebce4 100644
--- a/tests/unit/drivers/prompt/test_base_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py
@@ -1,9 +1,12 @@
-from griptape.artifacts import ErrorArtifact, TextArtifact
+import json
+
+from griptape.artifacts import ActionArtifact, ErrorArtifact, TextArtifact
 from griptape.common import Message, PromptStack
 from griptape.events import FinishPromptEvent, StartPromptEvent
 from griptape.events.event_bus import _EventBus
 from griptape.structures import Pipeline
 from griptape.tasks import PromptTask
+from griptape.tools.structured_output.tool import StructuredOutputTool
 from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver
 from tests.mocks.mock_prompt_driver import MockPromptDriver
 from tests.mocks.mock_tool.tool import MockTool
@@ -65,3 +68,74 @@ def test_run_with_tools_and_stream(self, mock_config):
         output = pipeline.run().output_task.output
         assert isinstance(output, TextArtifact)
         assert output.value == "mock output"
+
+    def test_native_structured_output_strategy(self):
+        from schema import Schema
+
+        prompt_driver = MockPromptDriver(
+            mock_structured_output={"baz": "foo"},
+            structured_output_strategy="native",
+        )
+
+        output_schema = Schema({"baz": str})
+        output = prompt_driver.run(PromptStack(messages=[], output_schema=output_schema)).to_artifact()
+
+        assert isinstance(output, TextArtifact)
+        assert output.value == json.dumps({"baz": "foo"})
+
+    def test_tool_structured_output_strategy(self):
+        from schema import Schema
+
+        output_schema = Schema({"baz": str})
+        prompt_driver = MockPromptDriver(
+            mock_structured_output={"baz": "foo"},
+            structured_output_strategy="tool",
+            use_native_tools=True,
+        )
+        prompt_stack = PromptStack(messages=[], output_schema=output_schema)
+        output = prompt_driver.run(prompt_stack).to_artifact()
+        output = prompt_driver.run(prompt_stack).to_artifact()
+
+        assert isinstance(output, ActionArtifact)
+        assert isinstance(prompt_stack.tools[0], StructuredOutputTool)
+        assert prompt_stack.tools[0].output_schema == output_schema
+        assert output.value.input == {"values": {"baz": "foo"}}
+
+    def test_rule_structured_output_strategy_empty(self):
+        from schema import Schema
+
+        output_schema = Schema({"baz": str})
+        prompt_driver = MockPromptDriver(
+            mock_structured_output={"baz": "foo"},
+            structured_output_strategy="rule",
+        )
+        prompt_stack = PromptStack(messages=[], output_schema=output_schema)
+        output = prompt_driver.run(prompt_stack).to_artifact()
+
+        assert len(prompt_stack.system_messages) == 1
+        assert prompt_stack.messages[0].is_system()
+        assert "baz" in prompt_stack.messages[0].content[0].to_text()
+        assert isinstance(output, TextArtifact)
+        assert output.value == json.dumps({"baz": "foo"})
+
+    def test_rule_structured_output_strategy_populated(self):
+        from schema import Schema
+
+        output_schema = Schema({"baz": str})
+        prompt_driver = MockPromptDriver(
+            mock_structured_output={"baz": "foo"},
+            structured_output_strategy="rule",
+        )
+        prompt_stack = PromptStack(
+            messages=[
+                Message(content="foo", role=Message.SYSTEM_ROLE),
+            ],
+            output_schema=output_schema,
+        )
+        output = prompt_driver.run(prompt_stack).to_artifact()
+        assert len(prompt_stack.system_messages) == 1
+        assert prompt_stack.messages[0].is_system()
+        assert prompt_stack.messages[0].content[1].to_text() == "\n\n"
+        assert "baz" in prompt_stack.messages[0].content[2].to_text()
+        assert isinstance(output, TextArtifact)
+        assert output.value == json.dumps({"baz": "foo"})
diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py
index 2cd102bf8..d146d2249 100644
--- a/tests/unit/tasks/test_prompt_task.py
+++ b/tests/unit/tasks/test_prompt_task.py
@@ -1,5 +1,7 @@
+import pytest
+import schema
+
 from griptape.artifacts.image_artifact import ImageArtifact
-from griptape.artifacts.json_artifact import JsonArtifact
 from griptape.artifacts.list_artifact import ListArtifact
 from griptape.artifacts.text_artifact import TextArtifact
 from griptape.memory.structure import ConversationMemory
@@ -174,50 +176,6 @@ def test_prompt_stack_empty_system_content(self):
         assert task.prompt_stack.messages[2].is_user()
         assert task.prompt_stack.messages[2].to_text() == "test value"
 
-    def test_prompt_stack_native_schema(self):
-        from schema import Schema
-
-        output_schema = Schema({"baz": str})
-        task = PromptTask(
-            input="foo",
-            prompt_driver=MockPromptDriver(
-                mock_structured_output={"baz": "foo"},
-                structured_output_strategy="native",
-            ),
-            output_schema=output_schema,
-        )
-        output = task.run()
-
-        assert isinstance(output, JsonArtifact)
-        assert output.value == {"baz": "foo"}
-
-        assert task.prompt_stack.output_schema is output_schema
-        assert task.prompt_stack.messages[0].is_user()
-        assert "foo" in task.prompt_stack.messages[0].to_text()
-
-    def test_prompt_stack_tool_schema(self):
-        from schema import Schema
-
-        output_schema = Schema({"baz": str})
-        task = PromptTask(
-            input="foo",
-            prompt_driver=MockPromptDriver(
-                mock_structured_output={"baz": "foo"},
-                structured_output_strategy="tool",
-                use_native_tools=True,
-            ),
-            output_schema=output_schema,
-        )
-        output = task.run()
-
-        assert isinstance(output, JsonArtifact)
-        assert output.value == {"baz": "foo"}
-
-        assert task.prompt_stack.output_schema is output_schema
-        assert task.prompt_stack.messages[0].is_system()
-        assert task.prompt_stack.messages[1].is_user()
-        assert "foo" in task.prompt_stack.messages[1].to_text()
-
     def test_prompt_stack_empty_native_schema(self):
         task = PromptTask(
             input="foo",
@@ -282,3 +240,19 @@ def test_subtasks(self):
         task.run()
 
         assert len(task.subtasks) == 2
+
+    @pytest.mark.parametrize("structured_output_strategy", ["native", "rule"])
+    def test_parse_output(self, structured_output_strategy):
+        task = PromptTask(
+            input="foo",
+            prompt_driver=MockPromptDriver(
+                structured_output_strategy=structured_output_strategy,
+                mock_structured_output={"foo": "bar"},
+            ),
+            output_schema=schema.Schema({"foo": str}),
+        )
+
+        task.run()
+
+        assert task.output is not None
+        assert task.output.value == {"foo": "bar"}
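Taken together, structured output can now be exercised at the Driver level without a Task, since `before_run()` injects the tool or rule based on the `PromptStack`'s `output_schema`. A hedged end-to-end sketch using the `rule` strategy (assumes `OpenAiChatPromptDriver`, an assumed model name, and an available API key; as the docs note, this strategy does not guarantee JSON output):

```python
from schema import Schema

from griptape.common import PromptStack
from griptape.drivers import OpenAiChatPromptDriver

driver = OpenAiChatPromptDriver(
    model="gpt-4o-mini",  # assumed model
    structured_output_strategy="rule",
)

stack = PromptStack(output_schema=Schema({"capital": str}))
stack.add_user_message("What is the capital of France?")

# run() calls before_run(), which appends the JsonSchemaRule text as a system message.
message = driver.run(stack)
print(message.to_text())  # expected, but not guaranteed, to be JSON matching the schema
```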