Move logic from task to base prompt driver
collindutter committed Jan 3, 2025
1 parent c9bcefa · commit d068967
Showing 7 changed files with 128 additions and 68 deletions.
4 changes: 2 additions & 2 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -41,8 +41,8 @@ The easiest way to get started with structured output is by using a [PromptTask]
You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: The Task will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and the Driver will try to force the LLM to use the Tool.
- `rule`: The Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort.
- `tool`: The Driver will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and will try to force the LLM to use the Tool.
- `rule`: The Driver will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort.
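
For instance, a minimal configuration sketch (the driver class and model here are illustrative; any prompt Driver supporting the strategy would work the same way):

```python
from schema import Schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

task = PromptTask(
    input="Extract the user's name from: {{ args[0] }}",
    output_schema=Schema({"name": str}),
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",  # illustrative driver and model choice
        structured_output_strategy="tool",
    ),
)
```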

```python
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
32 changes: 31 additions & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
@@ -5,7 +5,7 @@

from attrs import Factory, define, field

from griptape.artifacts.base_artifact import BaseArtifact
from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.common import (
    ActionCallDeltaMessageContent,
    ActionCallMessageContent,
@@ -26,6 +26,7 @@
)
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.rules.json_schema_rule import JsonSchemaRule

if TYPE_CHECKING:
    from collections.abc import Iterator
@@ -64,6 +65,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

    def before_run(self, prompt_stack: PromptStack) -> None:
        self._init_structured_output(prompt_stack)
        EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))

    def after_run(self, result: Message) -> None:
@@ -127,6 +129,34 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
    @abstractmethod
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

    def _init_structured_output(self, prompt_stack: PromptStack) -> None:
        from griptape.tools import StructuredOutputTool

        if (output_schema := prompt_stack.output_schema) is not None:
            if self.structured_output_strategy == "tool":
                # "tool" strategy: expose a tool whose input schema is the output schema.
                structured_output_tool = StructuredOutputTool(output_schema=output_schema)
                if structured_output_tool not in prompt_stack.tools:
                    prompt_stack.tools.append(structured_output_tool)
            elif self.structured_output_strategy == "rule":
                # "rule" strategy: fold the schema into the system prompt as a JsonSchemaRule.
                output_artifact = TextArtifact(JsonSchemaRule(output_schema.json_schema("Output Schema")).to_text())
                system_messages = prompt_stack.system_messages
                if system_messages:
                    # Append the rule text to the last existing system message.
                    last_system_message = prompt_stack.system_messages[-1]
                    last_system_message.content.extend(
                        [
                            TextMessageContent(TextArtifact("\n\n")),
                            TextMessageContent(output_artifact),
                        ]
                    )
                else:
                    # No system message yet; prepend one containing only the rule.
                    prompt_stack.messages.insert(
                        0,
                        Message(
                            content=[TextMessageContent(output_artifact)],
                            role=Message.SYSTEM_ROLE,
                        ),
                    )

    def __process_run(self, prompt_stack: PromptStack) -> Message:
        return self.try_run(prompt_stack)
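
As a sketch of the `rule` branch in action, borrowing `MockPromptDriver` from the tests further down in this diff:

```python
from schema import Schema

from griptape.common import PromptStack
from tests.mocks.mock_prompt_driver import MockPromptDriver  # test-only mock

stack = PromptStack(messages=[], output_schema=Schema({"baz": str}))
driver = MockPromptDriver(structured_output_strategy="rule")
driver.before_run(stack)  # calls _init_structured_output

# With no prior system message, the JsonSchemaRule text is prepended as one.
assert stack.messages[0].is_system()
assert "baz" in stack.messages[0].content[0].to_text()
```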

14 changes: 1 addition & 13 deletions griptape/tasks/prompt_task.py
@@ -15,7 +15,6 @@
from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin
from griptape.mixins.rule_mixin import RuleMixin
from griptape.rules import Ruleset
from griptape.rules.json_schema_rule import JsonSchemaRule
from griptape.tasks import ActionsSubtask, BaseTask
from griptape.utils import J2

@@ -92,16 +91,9 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask],

    @property
    def prompt_stack(self) -> PromptStack:
        from griptape.tools.structured_output.tool import StructuredOutputTool

        stack = PromptStack(tools=self.tools)
        stack = PromptStack(tools=self.tools, output_schema=self.output_schema)
        memory = self.structure.conversation_memory if self.structure is not None else None

        if self.output_schema is not None:
            stack.output_schema = self.output_schema
            if self.prompt_driver.structured_output_strategy == "tool":
                stack.tools.append(StructuredOutputTool(output_schema=stack.output_schema))

        system_template = self.generate_system_template(self)
        if system_template:
            stack.add_system_message(system_template)
@@ -227,10 +219,6 @@ def default_generate_system_template(self, _: PromptTask) -> str:
            actions_schema=utils.minify_json(json.dumps(schema)),
            meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
            use_native_tools=self.prompt_driver.use_native_tools,
            structured_output_strategy=self.prompt_driver.structured_output_strategy,
            json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output"))
            if self.output_schema is not None
            else None,
            stop_sequence=self.response_stop_sequence,
        )

4 changes: 0 additions & 4 deletions griptape/templates/tasks/prompt_task/system.j2
@@ -26,7 +26,3 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. NEVER

{{ rulesets }}
{% endif %}
{% if json_schema_rule and structured_output_strategy == 'rule' %}

{{ json_schema_rule }}
{% endif %}
2 changes: 0 additions & 2 deletions tests/mocks/mock_prompt_driver.py
@@ -36,7 +36,6 @@ class MockPromptDriver(BasePromptDriver):

    def try_run(self, prompt_stack: PromptStack) -> Message:
        output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output

        if self.use_native_tools and prompt_stack.tools:
            # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer.
            action_messages = [
@@ -85,7 +84,6 @@ def try_run(self, prompt_stack: PromptStack) -> Message:

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output

        if self.use_native_tools and prompt_stack.tools:
            # Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer.
            action_messages = [
76 changes: 75 additions & 1 deletion tests/unit/drivers/prompt/test_base_prompt_driver.py
@@ -1,9 +1,12 @@
from griptape.artifacts import ErrorArtifact, TextArtifact
import json

from griptape.artifacts import ActionArtifact, ErrorArtifact, TextArtifact
from griptape.common import Message, PromptStack
from griptape.events import FinishPromptEvent, StartPromptEvent
from griptape.events.event_bus import _EventBus
from griptape.structures import Pipeline
from griptape.tasks import PromptTask
from griptape.tools.structured_output.tool import StructuredOutputTool
from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver
from tests.mocks.mock_prompt_driver import MockPromptDriver
from tests.mocks.mock_tool.tool import MockTool
@@ -65,3 +68,74 @@ def test_run_with_tools_and_stream(self, mock_config):
        output = pipeline.run().output_task.output
        assert isinstance(output, TextArtifact)
        assert output.value == "mock output"

    def test_native_structured_output_strategy(self):
        from schema import Schema

        prompt_driver = MockPromptDriver(
            mock_structured_output={"baz": "foo"},
            structured_output_strategy="native",
        )

        output_schema = Schema({"baz": str})
        output = prompt_driver.run(PromptStack(messages=[], output_schema=output_schema)).to_artifact()

        assert isinstance(output, TextArtifact)
        assert output.value == json.dumps({"baz": "foo"})

    def test_tool_structured_output_strategy(self):
        from schema import Schema

        output_schema = Schema({"baz": str})
        prompt_driver = MockPromptDriver(
            mock_structured_output={"baz": "foo"},
            structured_output_strategy="tool",
            use_native_tools=True,
        )
        prompt_stack = PromptStack(messages=[], output_schema=output_schema)
        output = prompt_driver.run(prompt_stack).to_artifact()
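        # Run again: the membership check in _init_structured_output should keep
        # the Driver from appending a second StructuredOutputTool.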
        output = prompt_driver.run(prompt_stack).to_artifact()

        assert isinstance(output, ActionArtifact)
        assert isinstance(prompt_stack.tools[0], StructuredOutputTool)
        assert prompt_stack.tools[0].output_schema == output_schema
        assert output.value.input == {"values": {"baz": "foo"}}

    def test_rule_structured_output_strategy_empty(self):
        from schema import Schema

        output_schema = Schema({"baz": str})
        prompt_driver = MockPromptDriver(
            mock_structured_output={"baz": "foo"},
            structured_output_strategy="rule",
        )
        prompt_stack = PromptStack(messages=[], output_schema=output_schema)
        output = prompt_driver.run(prompt_stack).to_artifact()

        assert len(prompt_stack.system_messages) == 1
        assert prompt_stack.messages[0].is_system()
        assert "baz" in prompt_stack.messages[0].content[0].to_text()
        assert isinstance(output, TextArtifact)
        assert output.value == json.dumps({"baz": "foo"})

    def test_rule_structured_output_strategy_populated(self):
        from schema import Schema

        output_schema = Schema({"baz": str})
        prompt_driver = MockPromptDriver(
            mock_structured_output={"baz": "foo"},
            structured_output_strategy="rule",
        )
        prompt_stack = PromptStack(
            messages=[
                Message(content="foo", role=Message.SYSTEM_ROLE),
            ],
            output_schema=output_schema,
        )
        output = prompt_driver.run(prompt_stack).to_artifact()

        assert len(prompt_stack.system_messages) == 1
        assert prompt_stack.messages[0].is_system()
        assert prompt_stack.messages[0].content[1].to_text() == "\n\n"
        assert "baz" in prompt_stack.messages[0].content[2].to_text()
        assert isinstance(output, TextArtifact)
        assert output.value == json.dumps({"baz": "foo"})
64 changes: 19 additions & 45 deletions tests/unit/tasks/test_prompt_task.py
@@ -1,5 +1,7 @@
import pytest
import schema

from griptape.artifacts.image_artifact import ImageArtifact
from griptape.artifacts.json_artifact import JsonArtifact
from griptape.artifacts.list_artifact import ListArtifact
from griptape.artifacts.text_artifact import TextArtifact
from griptape.memory.structure import ConversationMemory
@@ -174,50 +176,6 @@ def test_prompt_stack_empty_system_content(self):
        assert task.prompt_stack.messages[2].is_user()
        assert task.prompt_stack.messages[2].to_text() == "test value"

    def test_prompt_stack_native_schema(self):
        from schema import Schema

        output_schema = Schema({"baz": str})
        task = PromptTask(
            input="foo",
            prompt_driver=MockPromptDriver(
                mock_structured_output={"baz": "foo"},
                structured_output_strategy="native",
            ),
            output_schema=output_schema,
        )
        output = task.run()

        assert isinstance(output, JsonArtifact)
        assert output.value == {"baz": "foo"}

        assert task.prompt_stack.output_schema is output_schema
        assert task.prompt_stack.messages[0].is_user()
        assert "foo" in task.prompt_stack.messages[0].to_text()

    def test_prompt_stack_tool_schema(self):
        from schema import Schema

        output_schema = Schema({"baz": str})
        task = PromptTask(
            input="foo",
            prompt_driver=MockPromptDriver(
                mock_structured_output={"baz": "foo"},
                structured_output_strategy="tool",
                use_native_tools=True,
            ),
            output_schema=output_schema,
        )
        output = task.run()

        assert isinstance(output, JsonArtifact)
        assert output.value == {"baz": "foo"}

        assert task.prompt_stack.output_schema is output_schema
        assert task.prompt_stack.messages[0].is_system()
        assert task.prompt_stack.messages[1].is_user()
        assert "foo" in task.prompt_stack.messages[1].to_text()

    def test_prompt_stack_empty_native_schema(self):
        task = PromptTask(
            input="foo",
@@ -282,3 +240,19 @@ def test_subtasks(self):

        task.run()
        assert len(task.subtasks) == 2

    @pytest.mark.parametrize("structured_output_strategy", ["native", "rule"])
    def test_parse_output(self, structured_output_strategy):
        task = PromptTask(
            input="foo",
            prompt_driver=MockPromptDriver(
                structured_output_strategy=structured_output_strategy,
                mock_structured_output={"foo": "bar"},
            ),
            output_schema=schema.Schema({"foo": str}),
        )

        task.run()

        assert task.output is not None
        assert task.output.value == {"foo": "bar"}
