From 87c8bb32e07507efe0250288a28c0d52e850db93 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 31 Dec 2024 12:25:33 -0800 Subject: [PATCH] WIP --- docs/griptape-framework/misc/events.md | 3 +- griptape/common/prompt_stack/prompt_stack.py | 6 +-- .../prompt/amazon_bedrock_prompt_driver.py | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 4 ++ .../prompt/openai_chat_prompt_driver.py | 4 +- griptape/schemas/base_schema.py | 2 + griptape/tasks/prompt_task.py | 30 ++++--------- tests/mocks/mock_prompt_driver.py | 2 +- .../test_amazon_bedrock_drivers_config.py | 2 + .../drivers/test_anthropic_drivers_config.py | 1 + .../test_azure_openai_drivers_config.py | 1 + .../drivers/test_cohere_drivers_config.py | 1 + .../configs/drivers/test_drivers_config.py | 1 + .../drivers/test_google_drivers_config.py | 1 + .../drivers/test_openai_driver_config.py | 1 + tests/unit/structures/test_structure.py | 1 + tests/unit/tasks/test_prompt_task.py | 42 +------------------ tests/unit/tasks/test_tool_task.py | 1 + tests/unit/tasks/test_toolkit_task.py | 1 + 19 files changed, 35 insertions(+), 73 deletions(-) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 0d30ed1e9..90df9434e 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -79,7 +79,8 @@ Handler 2 list[Message]: diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 9e754f6aa..20ee1b455 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -168,9 +168,7 @@ def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]: "name": tool.to_native_tool_name(activity), "description": tool.activity_description(activity), "inputSchema": { - "json": (tool.activity_schema(activity) or Schema({})).json_schema( - "http://json-schema.org/draft-07/schema#", - ), + "json": self.schema_driver.to_json_schema(tool.activity_schema(activity) or Schema({})), }, }, } diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 19109f55f..53bdc7f9a 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -17,6 +17,7 @@ TextMessageContent, observable, ) +from griptape.drivers import BaseSchemaDriver, SchemaSchemaDriver from griptape.events import ( ActionChunkEvent, EventBus, @@ -61,6 +62,9 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): default="native", kw_only=True, metadata={"serializable": True} ) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) + schema_driver: BaseSchemaDriver = field( + default=Factory(lambda: SchemaSchemaDriver()), kw_only=True, metadata={"serializable": True} + ) def before_run(self, prompt_stack: PromptStack) -> None: EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index d8f61a3bf..06da7e77f 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -165,7 +165,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "type": "json_schema", "json_schema": { "name": "Output", - "schema": prompt_stack.output_schema.json_schema("Output"), + "schema": self.schema_driver.to_json_schema(prompt_stack.output_schema), "strict": True, }, } @@ -254,7 +254,7 @@ def __to_openai_tools(self, tools: list[BaseTool]) -> list[dict]: "function": { "name": tool.to_native_tool_name(activity), "description": tool.activity_description(activity), - "parameters": (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema"), + "parameters": self.schema_driver.to_json_schema(tool.activity_schema(activity) or Schema({})), }, "type": "function", } diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 9217f26c2..7f4300bd5 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -169,6 +169,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: BaseImageGenerationDriver, BasePromptDriver, BaseRulesetDriver, + BaseSchemaDriver, BaseTextToSpeechDriver, BaseVectorStoreDriver, ) @@ -196,6 +197,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseConversationMemoryDriver": BaseConversationMemoryDriver, "BaseRulesetDriver": BaseRulesetDriver, "BaseImageGenerationDriver": BaseImageGenerationDriver, + "BaseSchemaDriver": BaseSchemaDriver, "BaseArtifact": BaseArtifact, "PromptStack": PromptStack, "EventListener": EventListener, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 99f0daa5d..6ff21a127 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -2,7 +2,7 @@ import json import logging -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from attrs import NOTHING, Attribute, Factory, NothingType, define, field from schema import Schema @@ -20,7 +20,7 @@ from griptape.utils import J2 if TYPE_CHECKING: - from griptape.drivers import BasePromptDriver, BaseSchemaDriver + from griptape.drivers import BasePromptDriver from griptape.memory import TaskMemory from griptape.memory.structure.base_conversation_memory import BaseConversationMemory from griptape.structures import Structure @@ -38,9 +38,7 @@ class PromptTask(BaseTask, RuleMixin, ActionsSubtaskOriginMixin): prompt_driver: BasePromptDriver = field( default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True, metadata={"serializable": True} ) - output_schema_driver: Optional[BaseSchemaDriver] = field( - default=None, kw_only=True, metadata={"serializable": True} - ) + output_schema: Optional[Any] = field(default=None, kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_generate_system_template, takes_self=True), kw_only=True, @@ -92,23 +90,12 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], @property def prompt_stack(self) -> PromptStack: - stack = PromptStack( - tools=self.tools, - output_schema=self.output_schema_driver.model if self.output_schema_driver is not None else None, - ) + stack = PromptStack(tools=self.tools, output_schema=self.output_schema) memory = self.structure.conversation_memory if self.structure is not None else None - rulesets = self.rulesets - system_artifacts = [TextArtifact(self.generate_system_template(self))] - - # Ensure there is at least one Ruleset that has non-empty `rules`. - if any(len(ruleset.rules) for ruleset in rulesets): - system_artifacts.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=rulesets))) - - # Ensure there is at least one system Artifact that has a non-empty value. - has_system_artifacts = any(system_artifact.value for system_artifact in system_artifacts) - if has_system_artifacts: - stack.add_system_message(ListArtifact(system_artifacts)) + system_template = self.generate_system_template(self) + if system_template: + stack.add_system_message(system_template) stack.add_user_message(self.input) @@ -119,7 +106,7 @@ def prompt_stack(self) -> PromptStack: if memory is not None: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if has_system_artifacts else 0) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0) return stack @@ -229,6 +216,7 @@ def default_generate_system_template(self, _: PromptTask) -> str: schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. return J2("tasks/prompt_task/system.j2").render( + rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index abef72227..af4b2c79a 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if prompt_stack.output_schema and self.use_native_structured_output: + if self.use_native_structured_output and prompt_stack.output_schema: if self.native_structured_output_strategy == "native": return Message( content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 59eb4ac61..e3e46dec6 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -54,6 +54,7 @@ def test_to_dict(self, config): "use_native_structured_output": True, "native_structured_output_strategy": "tool", "extra_params": {}, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "vector_store_driver": { "embedding_driver": { @@ -111,6 +112,7 @@ def test_to_dict_with_values(self, config_with_values): "use_native_structured_output": True, "native_structured_output_strategy": "tool", "extra_params": {}, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "vector_store_driver": { "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 66f987308..6c43448b5 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -28,6 +28,7 @@ def test_to_dict(self, config): "native_structured_output_strategy": "tool", "use_native_structured_output": True, "extra_params": {}, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 2281f4c11..0cf1958a0 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -39,6 +39,7 @@ def test_to_dict(self, config): "native_structured_output_strategy": "native", "use_native_structured_output": True, "extra_params": {}, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "conversation_memory_driver": { "type": "LocalConversationMemoryDriver", diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 6f371c5ba..97f2f6ea4 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -29,6 +29,7 @@ def test_to_dict(self, config): "use_native_structured_output": True, "native_structured_output_strategy": "native", "extra_params": {}, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "embedding_driver": { "type": "CohereEmbeddingDriver", diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index dd2e1736b..f5b217410 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -21,6 +21,7 @@ def test_to_dict(self, config): "use_native_structured_output": False, "native_structured_output_strategy": "native", "extra_params": {}, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "conversation_memory_driver": { "type": "LocalConversationMemoryDriver", diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 569e45561..fb9aa2f8c 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -28,6 +28,7 @@ def test_to_dict(self, config): "use_native_structured_output": True, "native_structured_output_strategy": "tool", "extra_params": {}, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index 603d9867a..f9cc312ad 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -31,6 +31,7 @@ def test_to_dict(self, config): "native_structured_output_strategy": "native", "use_native_structured_output": True, "extra_params": {}, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "conversation_memory_driver": { "type": "LocalConversationMemoryDriver", diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 5921d9e28..8aade1dc4 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -85,6 +85,7 @@ def test_to_dict(self): "use_native_tools": False, "use_native_structured_output": False, "native_structured_output_strategy": "native", + "schema_driver": {"type": "SchemaSchemaDriver"}, }, } ], diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 30a7001f9..e4d3060a5 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,7 +1,5 @@ import warnings -import pytest - from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.json_artifact import JsonArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -188,7 +186,7 @@ def test_prompt_stack_native_schema(self): use_native_structured_output=True, mock_structured_output={"baz": "foo"}, ), - rules=[JsonSchemaRule(output_schema)], + output_schema=output_schema, ) output = task.run() @@ -204,27 +202,6 @@ def test_prompt_stack_native_schema(self): warnings.simplefilter("error") assert task.prompt_stack - def test_prompt_stack_mixed_native_schema(self): - from schema import Schema - - output_schema = Schema({"baz": str}) - task = PromptTask( - input="foo", - prompt_driver=MockPromptDriver( - use_native_structured_output=True, - ), - rules=[Rule("foo"), JsonSchemaRule({"bar": {}}), JsonSchemaRule(output_schema)], - ) - - assert task.prompt_stack.output_schema is output_schema - assert task.prompt_stack.messages[0].is_system() - assert "foo" in task.prompt_stack.messages[0].to_text() - assert "bar" not in task.prompt_stack.messages[0].to_text() - with pytest.warns( - match="Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`." - ): - assert task.prompt_stack - def test_prompt_stack_empty_native_schema(self): task = PromptTask( input="foo", @@ -236,23 +213,6 @@ def test_prompt_stack_empty_native_schema(self): assert task.prompt_stack.output_schema is None - def test_prompt_stack_multi_native_schema(self): - from schema import Or, Schema - - output_schema = Schema({"foo": str}) - task = PromptTask( - input="foo", - prompt_driver=MockPromptDriver( - use_native_structured_output=True, - ), - rules=[JsonSchemaRule({"foo": {}}), JsonSchemaRule(output_schema), JsonSchemaRule(output_schema)], - ) - - assert isinstance(task.prompt_stack.output_schema, Schema) - assert task.prompt_stack.output_schema.json_schema("Output") == Schema( - Or(output_schema, output_schema) - ).json_schema("Output") - def test_rulesets(self): pipeline = Pipeline( rulesets=[Ruleset("Pipeline Ruleset")], diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index d7050c8f6..aa31a718c 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -260,6 +260,7 @@ def test_to_dict(self): "native_structured_output_strategy": "native", "use_native_structured_output": False, "use_native_tools": False, + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "tool": { "type": task.tool.type, diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 2503f1174..2cd2543b5 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -401,6 +401,7 @@ def test_to_dict(self): "use_native_tools": False, "use_native_structured_output": False, "native_structured_output_strategy": "native", + "schema_driver": {"type": "SchemaSchemaDriver"}, }, "tools": [ {