
Commit

WIP
collindutter committed Dec 31, 2024
1 parent 56608d2 commit 87c8bb3
Showing 19 changed files with 35 additions and 73 deletions.
3 changes: 2 additions & 1 deletion docs/griptape-framework/misc/events.md
@@ -79,7 +79,8 @@ Handler 2 <class 'griptape.events.finish_structure_run_event.FinishStructureRunE
You can use `Structure.run_stream()` for streaming Events from the `Structure` in the form of an iterator.

!!! tip
- Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.
+
+ Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.

```python
--8<-- "docs/griptape-framework/misc/src/events_streaming.py"
```
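For context, here is a minimal usage sketch of the documented pattern (not part of this commit); it assumes `Agent.run_stream()` accepts the same arguments as `run()` and that any Prompt Driver, here `OpenAiChatPromptDriver`, emits chunk events when constructed with `stream=True`:

```python
from griptape.drivers import OpenAiChatPromptDriver
from griptape.structures import Agent

# stream=True makes the Prompt Driver emit completion chunk events.
agent = Agent(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True))

# run_stream() returns an iterator of events emitted while the Structure runs.
for event in agent.run_stream("Summarize the benefits of streaming."):
    print(type(event).__name__)
```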
6 changes: 2 additions & 4 deletions griptape/common/prompt_stack/prompt_stack.py
@@ -1,6 +1,6 @@
from __future__ import annotations

- from typing import TYPE_CHECKING, Optional
+ from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

@@ -24,16 +24,14 @@
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
- from schema import Schema
-
from griptape.tools import BaseTool


@define
class PromptStack(SerializableMixin):
messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
tools: list[BaseTool] = field(factory=list, kw_only=True)
- output_schema: Optional[Schema] = field(default=None, kw_only=True)
+ output_schema: Optional[Any] = field(default=None, kw_only=True)

@property
def system_messages(self) -> list[Message]:
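A brief sketch of what the widened field allows, based only on the type change above (an assumption, not code from this commit): the stack can now carry any schema object, with conversion to JSON Schema deferred to a prompt driver's `schema_driver`.

```python
from schema import Schema

from griptape.common import PromptStack

# output_schema is now typed as Optional[Any], so any schema object can be attached;
# a schema.Schema is used here, but the field no longer requires one.
stack = PromptStack(output_schema=Schema({"answer": str}))
```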
4 changes: 1 addition & 3 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -168,9 +168,7 @@ def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
"name": tool.to_native_tool_name(activity),
"description": tool.activity_description(activity),
"inputSchema": {
"json": (tool.activity_schema(activity) or Schema({})).json_schema(
"http://json-schema.org/draft-07/schema#",
),
"json": self.schema_driver.to_json_schema(tool.activity_schema(activity) or Schema({})),
},
},
}
4 changes: 4 additions & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
@@ -17,6 +17,7 @@
TextMessageContent,
observable,
)
+ from griptape.drivers import BaseSchemaDriver, SchemaSchemaDriver
from griptape.events import (
ActionChunkEvent,
EventBus,
@@ -61,6 +62,9 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
default="native", kw_only=True, metadata={"serializable": True}
)
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
+ schema_driver: BaseSchemaDriver = field(
+ default=Factory(lambda: SchemaSchemaDriver()), kw_only=True, metadata={"serializable": True}
+ )

def before_run(self, prompt_stack: PromptStack) -> None:
EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))
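The `BaseSchemaDriver` and `SchemaSchemaDriver` sources are not included in this diff. The sketch below is a hypothetical reading of the interface implied by the call sites in this commit (`schema_driver.to_json_schema(...)` applied to a `schema.Schema`), not the actual griptape implementation; the fixed draft-07 schema id is also an assumption.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from schema import Schema


class BaseSchemaDriver(ABC):
    """Converts a schema object into a JSON Schema dictionary."""

    @abstractmethod
    def to_json_schema(self, schema: Any) -> dict: ...


class SchemaSchemaDriver(BaseSchemaDriver):
    """Handles `schema.Schema` objects, standing in for the inline `.json_schema(...)` calls it replaces."""

    def to_json_schema(self, schema: Any) -> dict:
        if isinstance(schema, Schema):
            # Schema id choice is assumed; the Bedrock driver previously passed the draft-07 URL.
            return schema.json_schema("http://json-schema.org/draft-07/schema#")
        raise ValueError(f"Unsupported schema type: {type(schema)}")
```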
4 changes: 2 additions & 2 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -165,7 +165,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"type": "json_schema",
"json_schema": {
"name": "Output",
"schema": prompt_stack.output_schema.json_schema("Output"),
"schema": self.schema_driver.to_json_schema(prompt_stack.output_schema),
"strict": True,
},
}
@@ -254,7 +254,7 @@ def __to_openai_tools(self, tools: list[BaseTool]) -> list[dict]:
"function": {
"name": tool.to_native_tool_name(activity),
"description": tool.activity_description(activity),
"parameters": (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema"),
"parameters": self.schema_driver.to_json_schema(tool.activity_schema(activity) or Schema({})),
},
"type": "function",
}
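For illustration, the shape of the `response_format` payload the `_base_params` change above assembles when a native output schema is set; the outer keys come from the diff, while the inner JSON Schema is an assumed example of what `to_json_schema` might return for `Schema({"answer": str})`.

```python
# Assumed example payload; the exact JSON Schema produced for the output schema may differ.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "Output",
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
            "additionalProperties": False,
        },
        "strict": True,
    },
}
```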
2 changes: 2 additions & 0 deletions griptape/schemas/base_schema.py
@@ -169,6 +169,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
BaseImageGenerationDriver,
BasePromptDriver,
BaseRulesetDriver,
+ BaseSchemaDriver,
BaseTextToSpeechDriver,
BaseVectorStoreDriver,
)
@@ -196,6 +197,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"BaseConversationMemoryDriver": BaseConversationMemoryDriver,
"BaseRulesetDriver": BaseRulesetDriver,
"BaseImageGenerationDriver": BaseImageGenerationDriver,
"BaseSchemaDriver": BaseSchemaDriver,
"BaseArtifact": BaseArtifact,
"PromptStack": PromptStack,
"EventListener": EventListener,
30 changes: 9 additions & 21 deletions griptape/tasks/prompt_task.py
@@ -2,7 +2,7 @@

import json
import logging
- from typing import TYPE_CHECKING, Callable, Optional, Union
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from attrs import NOTHING, Attribute, Factory, NothingType, define, field
from schema import Schema
@@ -20,7 +20,7 @@
from griptape.utils import J2

if TYPE_CHECKING:
- from griptape.drivers import BasePromptDriver, BaseSchemaDriver
+ from griptape.drivers import BasePromptDriver
from griptape.memory import TaskMemory
from griptape.memory.structure.base_conversation_memory import BaseConversationMemory
from griptape.structures import Structure
@@ -38,9 +38,7 @@ class PromptTask(BaseTask, RuleMixin, ActionsSubtaskOriginMixin):
prompt_driver: BasePromptDriver = field(
default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True, metadata={"serializable": True}
)
- output_schema_driver: Optional[BaseSchemaDriver] = field(
- default=None, kw_only=True, metadata={"serializable": True}
- )
+ output_schema: Optional[Any] = field(default=None, kw_only=True)
generate_system_template: Callable[[PromptTask], str] = field(
default=Factory(lambda self: self.default_generate_system_template, takes_self=True),
kw_only=True,
@@ -92,23 +90,12 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask],

@property
def prompt_stack(self) -> PromptStack:
- stack = PromptStack(
- tools=self.tools,
- output_schema=self.output_schema_driver.model if self.output_schema_driver is not None else None,
- )
+ stack = PromptStack(tools=self.tools, output_schema=self.output_schema)
memory = self.structure.conversation_memory if self.structure is not None else None

- rulesets = self.rulesets
- system_artifacts = [TextArtifact(self.generate_system_template(self))]
-
- # Ensure there is at least one Ruleset that has non-empty `rules`.
- if any(len(ruleset.rules) for ruleset in rulesets):
- system_artifacts.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=rulesets)))
-
- # Ensure there is at least one system Artifact that has a non-empty value.
- has_system_artifacts = any(system_artifact.value for system_artifact in system_artifacts)
- if has_system_artifacts:
- stack.add_system_message(ListArtifact(system_artifacts))
+ system_template = self.generate_system_template(self)
+ if system_template:
+ stack.add_system_message(system_template)

stack.add_user_message(self.input)

@@ -119,7 +106,7 @@ def prompt_stack(self) -> PromptStack:

if memory is not None:
# inserting at index 1 to place memory right after system prompt
- memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if has_system_artifacts else 0)
+ memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0)

return stack

@@ -229,6 +216,7 @@ def default_generate_system_template(self, _: PromptTask) -> str:
schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually.

return J2("tasks/prompt_task/system.j2").render(
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
action_names=str.join(", ", [tool.name for tool in self.tools]),
actions_schema=utils.minify_json(json.dumps(schema)),
meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
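A usage sketch of the task-level change, based on the field swap above and the test updates later in this diff (`output_schema_driver` replaced by a plain `output_schema`); prompt driver configuration is omitted and would be needed in practice.

```python
from schema import Schema

from griptape.tasks import PromptTask

# output_schema is passed straight through to PromptStack(tools=..., output_schema=...).
task = PromptTask(
    input="Return a JSON object with a `baz` key.",
    output_schema=Schema({"baz": str}),
)
```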
2 changes: 1 addition & 1 deletion tests/mocks/mock_prompt_driver.py
@@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver):

def try_run(self, prompt_stack: PromptStack) -> Message:
output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output
- if prompt_stack.output_schema and self.use_native_structured_output:
+ if self.use_native_structured_output and prompt_stack.output_schema:
if self.native_structured_output_strategy == "native":
return Message(
content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))],
@@ -54,6 +54,7 @@ def test_to_dict(self, config):
"use_native_structured_output": True,
"native_structured_output_strategy": "tool",
"extra_params": {},
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"vector_store_driver": {
"embedding_driver": {
@@ -111,6 +112,7 @@ def test_to_dict_with_values(self, config_with_values):
"use_native_structured_output": True,
"native_structured_output_strategy": "tool",
"extra_params": {},
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"vector_store_driver": {
"embedding_driver": {
@@ -28,6 +28,7 @@ def test_to_dict(self, config):
"native_structured_output_strategy": "tool",
"use_native_structured_output": True,
"extra_params": {},
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"embedding_driver": {
@@ -39,6 +39,7 @@ def test_to_dict(self, config):
"native_structured_output_strategy": "native",
"use_native_structured_output": True,
"extra_params": {},
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_cohere_drivers_config.py
@@ -29,6 +29,7 @@ def test_to_dict(self, config):
"use_native_structured_output": True,
"native_structured_output_strategy": "native",
"extra_params": {},
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"embedding_driver": {
"type": "CohereEmbeddingDriver",
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_drivers_config.py
@@ -21,6 +21,7 @@ def test_to_dict(self, config):
"use_native_structured_output": False,
"native_structured_output_strategy": "native",
"extra_params": {},
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_google_drivers_config.py
@@ -28,6 +28,7 @@ def test_to_dict(self, config):
"use_native_structured_output": True,
"native_structured_output_strategy": "tool",
"extra_params": {},
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"embedding_driver": {
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_openai_driver_config.py
@@ -31,6 +31,7 @@ def test_to_dict(self, config):
"native_structured_output_strategy": "native",
"use_native_structured_output": True,
"extra_params": {},
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"conversation_memory_driver": {
"type": "LocalConversationMemoryDriver",
1 change: 1 addition & 0 deletions tests/unit/structures/test_structure.py
@@ -85,6 +85,7 @@ def test_to_dict(self):
"use_native_tools": False,
"use_native_structured_output": False,
"native_structured_output_strategy": "native",
"schema_driver": {"type": "SchemaSchemaDriver"},
},
}
],
42 changes: 1 addition & 41 deletions tests/unit/tasks/test_prompt_task.py
@@ -1,7 +1,5 @@
import warnings

- import pytest
-
from griptape.artifacts.image_artifact import ImageArtifact
from griptape.artifacts.json_artifact import JsonArtifact
from griptape.artifacts.list_artifact import ListArtifact
@@ -188,7 +186,7 @@ def test_prompt_stack_native_schema(self):
use_native_structured_output=True,
mock_structured_output={"baz": "foo"},
),
- rules=[JsonSchemaRule(output_schema)],
+ output_schema=output_schema,
)
output = task.run()

@@ -204,27 +202,6 @@ def test_prompt_stack_native_schema(self):
warnings.simplefilter("error")
assert task.prompt_stack

- def test_prompt_stack_mixed_native_schema(self):
- from schema import Schema
-
- output_schema = Schema({"baz": str})
- task = PromptTask(
- input="foo",
- prompt_driver=MockPromptDriver(
- use_native_structured_output=True,
- ),
- rules=[Rule("foo"), JsonSchemaRule({"bar": {}}), JsonSchemaRule(output_schema)],
- )
-
- assert task.prompt_stack.output_schema is output_schema
- assert task.prompt_stack.messages[0].is_system()
- assert "foo" in task.prompt_stack.messages[0].to_text()
- assert "bar" not in task.prompt_stack.messages[0].to_text()
- with pytest.warns(
- match="Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`."
- ):
- assert task.prompt_stack
-
def test_prompt_stack_empty_native_schema(self):
task = PromptTask(
input="foo",
@@ -236,23 +213,6 @@ def test_prompt_stack_empty_native_schema(self):

assert task.prompt_stack.output_schema is None

- def test_prompt_stack_multi_native_schema(self):
- from schema import Or, Schema
-
- output_schema = Schema({"foo": str})
- task = PromptTask(
- input="foo",
- prompt_driver=MockPromptDriver(
- use_native_structured_output=True,
- ),
- rules=[JsonSchemaRule({"foo": {}}), JsonSchemaRule(output_schema), JsonSchemaRule(output_schema)],
- )
-
- assert isinstance(task.prompt_stack.output_schema, Schema)
- assert task.prompt_stack.output_schema.json_schema("Output") == Schema(
- Or(output_schema, output_schema)
- ).json_schema("Output")
-
def test_rulesets(self):
pipeline = Pipeline(
rulesets=[Ruleset("Pipeline Ruleset")],
1 change: 1 addition & 0 deletions tests/unit/tasks/test_tool_task.py
@@ -260,6 +260,7 @@ def test_to_dict(self):
"native_structured_output_strategy": "native",
"use_native_structured_output": False,
"use_native_tools": False,
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"tool": {
"type": task.tool.type,
1 change: 1 addition & 0 deletions tests/unit/tasks/test_toolkit_task.py
@@ -401,6 +401,7 @@ def test_to_dict(self):
"use_native_tools": False,
"use_native_structured_output": False,
"native_structured_output_strategy": "native",
"schema_driver": {"type": "SchemaSchemaDriver"},
},
"tools": [
{

