Skip to content

Commit

Permalink
Add Schema Drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 31, 2024
1 parent 4240f0a commit 56608d2
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 27 deletions.
7 changes: 7 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .schema.base_schema_driver import BaseSchemaDriver
from .schema.schema_schema_driver import SchemaSchemaDriver
from .schema.pydantic_schema_driver import PydanticSchemaDriver

from .prompt.base_prompt_driver import BasePromptDriver
from .prompt.openai_chat_prompt_driver import OpenAiChatPromptDriver
from .prompt.azure_openai_chat_prompt_driver import AzureOpenAiChatPromptDriver
Expand Down Expand Up @@ -240,4 +244,7 @@
"BaseAssistantDriver",
"GriptapeCloudAssistantDriver",
"OpenAiAssistantDriver",
"BaseSchemaDriver",
"SchemaSchemaDriver",
"PydanticSchemaDriver",
]
Empty file.
17 changes: 17 additions & 0 deletions griptape/drivers/schema/base_schema_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

from attrs import define

from griptape.mixins.serializable_mixin import SerializableMixin

T = TypeVar("T")


@define
class BaseSchemaDriver(ABC, SerializableMixin, Generic[T]):
@abstractmethod
def validate(self, model: T, data: Any) -> Any: ...

@abstractmethod
def to_json_schema(self, model: T) -> dict: ...
15 changes: 15 additions & 0 deletions griptape/drivers/schema/pydantic_schema_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Any

from attrs import define
from pydantic import BaseModel

from griptape.drivers import BaseSchemaDriver


@define
class PydanticSchemaDriver(BaseSchemaDriver[BaseModel]):
def validate(self, model: BaseModel, data: Any) -> Any:
return model.model_validate(data)

def to_json_schema(self, model: BaseModel) -> dict:
return model.model_dump()
15 changes: 15 additions & 0 deletions griptape/drivers/schema/schema_schema_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Any

from attrs import define
from schema import Schema

from griptape.drivers import BaseSchemaDriver


@define
class SchemaSchemaDriver(BaseSchemaDriver[Schema]):
def validate(self, model: Schema, data: Any) -> Any:
return model.validate(data)

def to_json_schema(self, model: Schema) -> dict:
return model.json_schema("Schema")
37 changes: 10 additions & 27 deletions griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import json
import logging
import warnings
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import NOTHING, Attribute, Factory, NothingType, define, field
from schema import Or, Schema
from schema import Schema

from griptape import utils
from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact
Expand All @@ -16,12 +15,12 @@
from griptape.memory.structure import Run
from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin
from griptape.mixins.rule_mixin import RuleMixin
from griptape.rules import JsonSchemaRule, Ruleset
from griptape.rules import Ruleset
from griptape.tasks import ActionsSubtask, BaseTask
from griptape.utils import J2

if TYPE_CHECKING:
from griptape.drivers import BasePromptDriver
from griptape.drivers import BasePromptDriver, BaseSchemaDriver
from griptape.memory import TaskMemory
from griptape.memory.structure.base_conversation_memory import BaseConversationMemory
from griptape.structures import Structure
Expand All @@ -39,6 +38,9 @@ class PromptTask(BaseTask, RuleMixin, ActionsSubtaskOriginMixin):
prompt_driver: BasePromptDriver = field(
default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True, metadata={"serializable": True}
)
output_schema_driver: Optional[BaseSchemaDriver] = field(
default=None, kw_only=True, metadata={"serializable": True}
)
generate_system_template: Callable[[PromptTask], str] = field(
default=Factory(lambda self: self.default_generate_system_template, takes_self=True),
kw_only=True,
Expand Down Expand Up @@ -90,13 +92,14 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask],

@property
def prompt_stack(self) -> PromptStack:
stack = PromptStack(tools=self.tools)
stack = PromptStack(
tools=self.tools,
output_schema=self.output_schema_driver.model if self.output_schema_driver is not None else None,
)
memory = self.structure.conversation_memory if self.structure is not None else None

rulesets = self.rulesets
system_artifacts = [TextArtifact(self.generate_system_template(self))]
if self.prompt_driver.use_native_structured_output:
self._add_native_schema_to_prompt_stack(stack, rulesets)

# Ensure there is at least one Ruleset that has non-empty `rules`.
if any(len(ruleset.rules) for ruleset in rulesets):
Expand Down Expand Up @@ -307,26 +310,6 @@ def _process_task_input(
else:
return self._process_task_input(TextArtifact(task_input))

def _add_native_schema_to_prompt_stack(self, stack: PromptStack, rulesets: list[Ruleset]) -> None:
# Need to separate JsonSchemaRules from other rules, removing them in the process
json_schema_rules = [rule for ruleset in rulesets for rule in ruleset.rules if isinstance(rule, JsonSchemaRule)]
non_json_schema_rules = [
[rule for rule in ruleset.rules if not isinstance(rule, JsonSchemaRule)] for ruleset in rulesets
]
for ruleset, non_json_rules in zip(rulesets, non_json_schema_rules):
ruleset.rules = non_json_rules

schemas = [rule.value for rule in json_schema_rules if isinstance(rule.value, Schema)]

if len(json_schema_rules) != len(schemas):
warnings.warn(
"Not all provided `JsonSchemaRule`s include a `schema.Schema` instance. These will be ignored with `use_native_structured_output`.",
stacklevel=2,
)

if schemas:
stack.output_schema = schemas[0] if len(schemas) == 1 else Schema(Or(*schemas))

def _add_subtasks_to_prompt_stack(self, stack: PromptStack) -> None:
for s in self.subtasks:
if self.prompt_driver.use_native_tools:
Expand Down

0 comments on commit 56608d2

Please sign in to comment.