Skip to content

Commit

Permalink
Revert removal of use_native_structured_output, add fallback to JsonS…
Browse files Browse the repository at this point in the history
…chemaRule
  • Loading branch information
collindutter committed Jan 3, 2025
1 parent 23494a1 commit 541fc54
Show file tree
Hide file tree
Showing 34 changed files with 217 additions and 60 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Structure.run_stream()` for streaming Events from a Structure as an iterator.
- Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`.
- Validators to `Agent` initialization.
- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output.
- `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`.

### Changed
Expand Down
8 changes: 6 additions & 2 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ You can pass images to the Driver if the model supports it:

Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems.

You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of:
Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output).

If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md) and try to force the LLM to use a Tool.
Expand All @@ -44,8 +46,10 @@ The easiest way to get started with structured output is by using a [PromptTask]
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
```

If `use_native_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt.

!!! warning
Not every LLM supports all `structured_output_strategy` options.
Not every LLM supports `use_native_structured_output` or all `structured_output_strategy` options.

## Prompt Drivers

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
PromptTask(
prompt_driver=OpenAiChatPromptDriver(
model="gpt-4o",
use_native_structured_output=True, # optional
structured_output_strategy="native", # optional
),
output_schema=schema.Schema(
Expand Down
4 changes: 4 additions & 0 deletions docs/griptape-framework/structures/rulesets.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru

### Json Schema

!!! tip
[Structured Output](../drivers/prompt-drivers.md#structured-output) provides a more robust solution for having the LLM generate structured output.
And if an LLM does not natively support structured output, a `JsonSchemaRule` will automatically be added.

[JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema.
This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types.

Expand Down
7 changes: 6 additions & 1 deletion griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
Expand Down Expand Up @@ -133,7 +134,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"toolChoice": self.tool_choice,
}

if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool(prompt_stack)
params["toolConfig"]["toolChoice"] = {"any": {}}

Expand Down
7 changes: 6 additions & 1 deletion griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
Expand Down Expand Up @@ -135,7 +136,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools:
params["tool_choice"] = self.tool_choice

if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.structured_output_strategy == "tool"
):
self._add_structured_output_tool(prompt_stack)
params["tool_choice"] = {"type": "any"}

Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
Expand Down
3 changes: 2 additions & 1 deletion griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver):
model: str = field(metadata={"serializable": True})
force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
Expand Down Expand Up @@ -111,7 +112,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema is not None:
if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_object",
Expand Down
7 changes: 6 additions & 1 deletion griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class GooglePromptDriver(BasePromptDriver):
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
Expand Down Expand Up @@ -163,7 +164,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools:
params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.structured_output_strategy == "tool"
):
params["tool_config"]["function_calling_config"]["mode"] = "auto"
self._add_structured_output_tool(prompt_stack)

Expand Down
7 changes: 6 additions & 1 deletion griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
api_token: str = field(kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
Expand Down Expand Up @@ -120,7 +121,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema and self.structured_output_strategy == "native":
if (
prompt_stack.output_schema
and self.use_native_structured_output
and self.structured_output_strategy == "native"
):
# https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
output_schema = prompt_stack.output_schema.json_schema("Output Schema")
# Grammar does not support $schema and $id
Expand Down
3 changes: 2 additions & 1 deletion griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class OllamaPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
Expand Down Expand Up @@ -109,7 +110,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema is not None:
if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.structured_output_strategy == "native":
params["format"] = prompt_stack.output_schema.json_schema("Output")
elif self.structured_output_strategy == "tool":
Expand Down
3 changes: 2 additions & 1 deletion griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True})
ignored_exception_types: tuple[type[Exception], ...] = field(
default=Factory(
Expand Down Expand Up @@ -158,7 +159,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
params["tool_choice"] = self.tool_choice
params["parallel_tool_calls"] = self.parallel_tool_calls

if prompt_stack.output_schema is not None:
if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_schema",
Expand Down
11 changes: 10 additions & 1 deletion griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ def try_run(self) -> BaseArtifact:
else:
output = result.to_artifact()

if self.output_schema is not None and self.prompt_driver.structured_output_strategy == "native":
if (
self.prompt_driver.use_native_structured_output
and self.prompt_driver.structured_output_strategy == "native"
):
return JsonArtifact(output.value)
else:
return output
Expand All @@ -210,6 +213,8 @@ def preprocess(self, structure: Structure) -> BaseTask:
return self

def default_generate_system_template(self, _: PromptTask) -> str:
from griptape.rules import JsonSchemaRule

schema = self.actions_schema().json_schema("Actions Schema")
schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually.

Expand All @@ -219,6 +224,10 @@ def default_generate_system_template(self, _: PromptTask) -> str:
actions_schema=utils.minify_json(json.dumps(schema)),
meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
use_native_tools=self.prompt_driver.use_native_tools,
use_native_structured_output=self.prompt_driver.use_native_structured_output,
json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output Schema"))
if self.output_schema is not None
else None,
stop_sequence=self.response_stop_sequence,
)

Expand Down
4 changes: 4 additions & 0 deletions griptape/templates/tasks/prompt_task/system.j2
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. NEVER

{{ rulesets }}
{% endif %}
{% if not use_native_structured_output and json_schema_rule %}

{{ json_schema_rule }}
{% endif %}
4 changes: 2 additions & 2 deletions tests/mocks/mock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver):

def try_run(self, prompt_stack: PromptStack) -> Message:
output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output
if prompt_stack.output_schema is not None:
if self.use_native_structured_output and prompt_stack.output_schema is not None:
if self.structured_output_strategy == "native":
return Message(
content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))],
Expand Down Expand Up @@ -84,7 +84,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output

if prompt_stack.output_schema is not None:
if self.use_native_structured_output and prompt_stack.output_schema is not None:
if self.structured_output_strategy == "native":
yield DeltaMessage(
content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_to_dict(self, config):
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"use_native_tools": True,
"use_native_structured_output": True,
"structured_output_strategy": "tool",
"extra_params": {},
},
Expand Down Expand Up @@ -107,6 +108,7 @@ def test_to_dict_with_values(self, config_with_values):
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"use_native_tools": True,
"use_native_structured_output": True,
"structured_output_strategy": "tool",
"extra_params": {},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_to_dict(self, config):
"top_k": 250,
"use_native_tools": True,
"structured_output_strategy": "tool",
"use_native_structured_output": True,
"extra_params": {},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_to_dict(self, config):
"user": "",
"use_native_tools": True,
"structured_output_strategy": "native",
"use_native_structured_output": True,
"extra_params": {},
},
"conversation_memory_driver": {
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_cohere_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_to_dict(self, config):
"model": "command-r",
"force_single_step": False,
"use_native_tools": True,
"use_native_structured_output": True,
"structured_output_strategy": "native",
"extra_params": {},
},
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def test_to_dict(self, config):
"max_tokens": None,
"stream": False,
"use_native_tools": False,
"use_native_structured_output": False,
"structured_output_strategy": "native",
"extra_params": {},
},
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_google_drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_to_dict(self, config):
"top_k": None,
"tool_choice": "auto",
"use_native_tools": True,
"use_native_structured_output": True,
"structured_output_strategy": "tool",
"extra_params": {},
},
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_openai_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_to_dict(self, config):
"user": "",
"use_native_tools": True,
"structured_output_strategy": "native",
"use_native_structured_output": True,
"extra_params": {},
},
"conversation_memory_driver": {
Expand Down
Loading

0 comments on commit 541fc54

Please sign in to comment.