Remove use_native_structured_output toggle
collindutter committed Jan 2, 2025
1 parent 3762fc7 commit e5ca892
Showing 33 changed files with 72 additions and 200 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
@@ -21,7 +21,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Structure.run_stream()` for streaming Events from a Structure as an iterator.
- Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`.
- Validators to `Agent` initialization.
-- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output.
- `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`.

### Changed
6 changes: 2 additions & 4 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -29,9 +29,7 @@ You can pass images to the Driver if the model supports it:

Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems.

-Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output).
-
-If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of:
+If an [output_schema](../../reference/griptape/tasks.md#griptape.tasks.PromptTask.output_schema) is provided to the Task, you can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output_tool.md) and try to force the LLM to use a Tool.
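To make the two strategies above concrete, here is a minimal sketch of configuring a Task after this change. The model name, schema shape, flat `griptape.drivers` import path, and the direct `task.run()` call are illustrative assumptions, not part of this diff:

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

# The Task owns the schema; the Driver's strategy only decides *how* it is enforced.
task = PromptTask(
    "What color is the sky? Answer in the requested format.",
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",
        structured_output_strategy="native",  # or "tool" to route through StructuredOutputTool
    ),
    output_schema=schema.Schema({"answer": str}),
)

print(task.run().value)  # assumed to return the Task's output artifact
```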
@@ -47,7 +45,7 @@ The easiest way to get started with structured output is by using a [JsonSchemaR
```

!!! warning
-    Not every LLM supports `use_native_structured_output` or all `structured_output_strategy` options.
+    Not every LLM supports all `structured_output_strategy` options.

## Prompt Drivers

@@ -11,7 +11,6 @@
PromptTask(
prompt_driver=OpenAiChatPromptDriver(
model="gpt-4o",
-use_native_structured_output=True,
structured_output_strategy="native",
),
output_schema=schema.Schema(
7 changes: 1 addition & 6 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -55,7 +55,6 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
-use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
@@ -134,11 +133,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"toolChoice": self.tool_choice,
}

-if (
-    prompt_stack.output_schema is not None
-    and self.use_native_structured_output
-    and self.structured_output_strategy == "tool"
-):
+if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
self._add_structured_output_tool(prompt_stack)
params["toolConfig"]["toolChoice"] = {"any": {}}

7 changes: 1 addition & 6 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -68,7 +68,6 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
-use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
@@ -136,11 +135,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools:
params["tool_choice"] = self.tool_choice

-if (
-    prompt_stack.output_schema is not None
-    and self.use_native_structured_output
-    and self.structured_output_strategy == "tool"
-):
+if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
self._add_structured_output_tool(prompt_stack)
params["tool_choice"] = {"type": "any"}

1 change: 0 additions & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
@@ -56,7 +56,6 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
-use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
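For code that previously set the removed flag, a hedged before/after sketch (the driver class and model are borrowed from the docs snippet above; this is not an official migration guide):

```python
from griptape.drivers import OpenAiChatPromptDriver

# Before this commit, two knobs controlled structured output:
# driver = OpenAiChatPromptDriver(
#     model="gpt-4o",
#     use_native_structured_output=True,   # removed by this commit
#     structured_output_strategy="native",
# )

# After this commit, the strategy is the only knob, and it only takes effect
# when the Task/PromptStack carries an output_schema.
driver = OpenAiChatPromptDriver(
    model="gpt-4o",
    structured_output_strategy="native",  # or "tool"
)
```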
3 changes: 1 addition & 2 deletions griptape/drivers/prompt/cohere_prompt_driver.py
@@ -53,7 +53,6 @@ class CoherePromptDriver(BasePromptDriver):
model: str = field(metadata={"serializable": True})
force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
-use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
@@ -112,7 +111,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

-if prompt_stack.output_schema is not None and self.use_native_structured_output:
+if prompt_stack.output_schema is not None:
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_object",
7 changes: 1 addition & 6 deletions griptape/drivers/prompt/google_prompt_driver.py
@@ -63,7 +63,6 @@ class GooglePromptDriver(BasePromptDriver):
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
-use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
@@ -164,11 +163,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools:
params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

-if (
-    prompt_stack.output_schema is not None
-    and self.use_native_structured_output
-    and self.structured_output_strategy == "tool"
-):
+if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["tool_config"]["function_calling_config"]["mode"] = "auto"
self._add_structured_output_tool(prompt_stack)

7 changes: 1 addition & 6 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@@ -35,7 +35,6 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
api_token: str = field(kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
-use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
@@ -121,11 +120,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

-if (
-    prompt_stack.output_schema
-    and self.use_native_structured_output
-    and self.structured_output_strategy == "native"
-):
+if prompt_stack.output_schema and self.structured_output_strategy == "native":
# https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
output_schema = prompt_stack.output_schema.json_schema("Output Schema")
# Grammar does not support $schema and $id
3 changes: 1 addition & 2 deletions griptape/drivers/prompt/ollama_prompt_driver.py
@@ -68,7 +68,6 @@ class OllamaPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
-use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
@@ -110,7 +109,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

-if prompt_stack.output_schema is not None and self.use_native_structured_output:
+if prompt_stack.output_schema is not None:
if self.structured_output_strategy == "native":
params["format"] = prompt_stack.output_schema.json_schema("Output")
elif self.structured_output_strategy == "tool":
3 changes: 1 addition & 2 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -76,7 +76,6 @@ class OpenAiChatPromptDriver(BasePromptDriver):
seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
-use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True})
ignored_exception_types: tuple[type[Exception], ...] = field(
default=Factory(
@@ -159,7 +158,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
params["tool_choice"] = self.tool_choice
params["parallel_tool_calls"] = self.parallel_tool_calls

-if prompt_stack.output_schema is not None and self.use_native_structured_output:
+if prompt_stack.output_schema is not None:
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_schema",
3 changes: 3 additions & 0 deletions griptape/schemas/base_schema.py
@@ -151,6 +151,8 @@ def _resolve_types(cls, attrs_cls: type) -> None:
from collections.abc import Sequence
from typing import Any

+from schema import Schema

from griptape.artifacts import BaseArtifact
from griptape.common import (
BaseDeltaMessageContent,
@@ -228,6 +230,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
if is_dependency_installed("mypy_boto3_bedrock")
else Any,
"voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any,
"Schema": Schema,
},
)

5 changes: 1 addition & 4 deletions griptape/tasks/prompt_task.py
@@ -190,10 +190,7 @@ def try_run(self) -> BaseArtifact:
else:
output = result.to_artifact()

-if (
-    self.prompt_driver.use_native_structured_output
-    and self.prompt_driver.structured_output_strategy == "native"
-):
+if self.output_schema is not None and self.prompt_driver.structured_output_strategy == "native":
return JsonArtifact(output.value)
else:
return output
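The net effect of the simplified check above, sketched as a hypothetical assertion; the imports and the direct `run()` call mirror the assumptions in the docs sketch earlier in this diff:

```python
import schema

from griptape.artifacts import JsonArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

task = PromptTask(
    "Name a primary color.",
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", structured_output_strategy="native"),
    output_schema=schema.Schema({"color": str}),
)

# output_schema set + "native" strategy -> the Task wraps the response in a JsonArtifact.
# Any other combination falls through and returns the Driver's artifact unchanged.
result = task.run()
assert isinstance(result, JsonArtifact)
print(result.value)
```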
12 changes: 11 additions & 1 deletion tests/mocks/mock_prompt_driver.py
@@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver):

def try_run(self, prompt_stack: PromptStack) -> Message:
output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output
-if self.use_native_structured_output and prompt_stack.output_schema:
+if prompt_stack.output_schema is not None:
if self.structured_output_strategy == "native":
return Message(
content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))],
@@ -84,6 +84,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output

+if prompt_stack.output_schema is not None:
+    if self.structured_output_strategy == "native":
+        yield DeltaMessage(
+            content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)),
+            role=Message.ASSISTANT_ROLE,
+            usage=Message.Usage(input_tokens=100, output_tokens=100),
+        )
+    elif self.structured_output_strategy == "tool":
+        self._add_structured_output_tool(prompt_stack)

if self.use_native_tools and prompt_stack.tools:
# Hack to simulate CoT. If there are any action messages in the prompt stack, give the answer.
action_messages = [
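A hedged sketch of how a unit test might exercise the updated mock. The constructor-settable fields and the assignable `output_schema` attribute on `PromptStack` are assumptions inferred from the code above, not verified API:

```python
import json

import schema

from griptape.common import PromptStack
from tests.mocks.mock_prompt_driver import MockPromptDriver

prompt_stack = PromptStack()
prompt_stack.output_schema = schema.Schema({"answer": str})  # assumed assignable

driver = MockPromptDriver(
    mock_structured_output={"answer": "42"},  # assumed constructor kwarg
    structured_output_strategy="native",
)

# With the "native" strategy the mock returns the structured payload as JSON text.
message = driver.try_run(prompt_stack)
assert json.loads(message.to_artifact().value) == {"answer": "42"}
```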
@@ -51,7 +51,6 @@ def test_to_dict(self, config):
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"use_native_tools": True,
"use_native_structured_output": True,
"structured_output_strategy": "tool",
"extra_params": {},
},
@@ -108,7 +107,6 @@ def test_to_dict_with_values(self, config_with_values):
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"use_native_tools": True,
"use_native_structured_output": True,
"structured_output_strategy": "tool",
"extra_params": {},
},
@@ -26,7 +26,6 @@ def test_to_dict(self, config):
"top_k": 250,
"use_native_tools": True,
"structured_output_strategy": "tool",
"use_native_structured_output": True,
"extra_params": {},
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
@@ -37,7 +37,6 @@ def test_to_dict(self, config):
"user": "",
"use_native_tools": True,
"structured_output_strategy": "native",
"use_native_structured_output": True,
"extra_params": {},
},
"conversation_memory_driver": {
1 change: 0 additions & 1 deletion tests/unit/configs/drivers/test_cohere_drivers_config.py
@@ -26,7 +26,6 @@ def test_to_dict(self, config):
"model": "command-r",
"force_single_step": False,
"use_native_tools": True,
"use_native_structured_output": True,
"structured_output_strategy": "native",
"extra_params": {},
},
1 change: 0 additions & 1 deletion tests/unit/configs/drivers/test_drivers_config.py
@@ -18,7 +18,6 @@ def test_to_dict(self, config):
"max_tokens": None,
"stream": False,
"use_native_tools": False,
"use_native_structured_output": False,
"structured_output_strategy": "native",
"extra_params": {},
},
1 change: 0 additions & 1 deletion tests/unit/configs/drivers/test_google_drivers_config.py
@@ -25,7 +25,6 @@ def test_to_dict(self, config):
"top_k": None,
"tool_choice": "auto",
"use_native_tools": True,
"use_native_structured_output": True,
"structured_output_strategy": "tool",
"extra_params": {},
},
1 change: 0 additions & 1 deletion tests/unit/configs/drivers/test_openai_driver_config.py
@@ -29,7 +29,6 @@ def test_to_dict(self, config):
"user": "",
"use_native_tools": True,
"structured_output_strategy": "native",
"use_native_structured_output": True,
"extra_params": {},
},
"conversation_memory_driver": {
22 changes: 10 additions & 12 deletions tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py
@@ -384,13 +384,11 @@ def messages(self):
]

@pytest.mark.parametrize("use_native_tools", [True, False])
@pytest.mark.parametrize("use_native_structured_output", [True, False])
def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_native_structured_output):
def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools):
# Given
driver = AmazonBedrockPromptDriver(
model="ai21.j2",
use_native_tools=use_native_tools,
-use_native_structured_output=use_native_structured_output,
extra_params={"foo": "bar"},
)

@@ -414,11 +412,13 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools,
*self.BEDROCK_TOOLS,
*(
[self.BEDROCK_STRUCTURED_OUTPUT_TOOL]
if use_native_structured_output and driver.structured_output_strategy == "tool"
if driver.structured_output_strategy == "tool"
else []
),
],
"toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice,
"toolChoice": {"any": {}}
if driver.structured_output_strategy == "tool"
else driver.tool_choice,
}
}
if use_native_tools
@@ -437,16 +437,12 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools,
assert message.usage.output_tokens == 10

@pytest.mark.parametrize("use_native_tools", [True, False])
@pytest.mark.parametrize("use_native_structured_output", [True, False])
def test_try_stream_run(
self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_native_structured_output
):
def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools):
# Given
driver = AmazonBedrockPromptDriver(
model="ai21.j2",
stream=True,
use_native_tools=use_native_tools,
-use_native_structured_output=use_native_structured_output,
extra_params={"foo": "bar"},
)

@@ -471,11 +467,13 @@ def test_try_stream_run(
*self.BEDROCK_TOOLS,
*(
[self.BEDROCK_STRUCTURED_OUTPUT_TOOL]
if use_native_structured_output and driver.structured_output_strategy == "tool"
if driver.structured_output_strategy == "tool"
else []
),
],
"toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice,
"toolChoice": {"any": {}}
if driver.structured_output_strategy == "tool"
else driver.tool_choice,
}
}
if use_native_tools