diff --git a/CHANGELOG.md b/CHANGELOG.md index e07f38305..9c322a4db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Structure.run_stream()` for streaming Events from a Structure as an iterator. - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. +- `BasePromptDriver.use_native_structured_output` for enabling or disabling structured output. - `BasePromptDriver.structured_output_strategy` for changing the structured output strategy between `native` and `tool`. ### Changed diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index dfc2c8b56..aede7fe01 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -29,7 +29,9 @@ You can pass images to the Driver if the model supports it: Some LLMs provide functionality often referred to as "Structured Output". This means instructing the LLM to output data in a particular format, usually JSON. This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems. -You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of: +Structured output can be enabled or disabled for a Prompt Driver by setting the [use_native_structured_output](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.use_native_structured_output) parameter. + +If `use_native_structured_output=True`, you can change _how_ the output is structured by setting the [structured_output_strategy](../../reference/griptape/drivers.md#griptape.drivers.BasePromptDriver.structured_output_strategy) to one of: - `native`: The Driver will use the LLM's structured output functionality provided by the API. - `tool`: Griptape will pass a special Tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md) and try to force the LLM to use a Tool. @@ -44,8 +46,10 @@ The easiest way to get started with structured output is by using a [PromptTask] --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py" ``` +If `use_native_structured_output=False`, the Task will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. + !!! warning - Not every LLM supports all `structured_output_strategy` options. + Not every LLM supports `use_native_structured_output` or all `structured_output_strategy` options.
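To make the two new settings concrete before the implementation diffs below, here is a minimal sketch of a Task that keeps structured output enabled but routes it through the Tool-based strategy. It mirrors the docs example referenced above; the imports, the `{"answer": str}` schema, the input text, and the final `print` are illustrative assumptions, not part of this change.

```python
import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.tasks import PromptTask

# Minimal sketch (assumed imports and schema): structured output stays enabled,
# but is produced via Griptape's StructuredOutputTool rather than the API's
# native JSON mode.
task = PromptTask(
    input="What is the answer to life, the universe, and everything?",
    prompt_driver=OpenAiChatPromptDriver(
        model="gpt-4o",
        use_native_structured_output=True,  # enable structured output
        structured_output_strategy="tool",  # force output through a special Tool
    ),
    output_schema=schema.Schema({"answer": str}),
)

print(task.run().value)
```

With `structured_output_strategy="native"` the Task returns a `JsonArtifact` directly, and with `use_native_structured_output=False` it falls back to injecting a `JsonSchemaRule` into the system prompt, as the `prompt_task.py` and `system.j2` changes later in this diff show.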
## Prompt Drivers diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py index cb7eb5ceb..adc7ea7ad 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py @@ -11,6 +11,7 @@ PromptTask( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", + use_native_structured_output=True, # optional structured_output_strategy="native", # optional ), output_schema=schema.Schema( diff --git a/docs/griptape-framework/structures/rulesets.md b/docs/griptape-framework/structures/rulesets.md index f7a1de482..93e5a4c2b 100644 --- a/docs/griptape-framework/structures/rulesets.md +++ b/docs/griptape-framework/structures/rulesets.md @@ -26,6 +26,10 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru ### Json Schema +!!! tip + [Structured Output](../drivers/prompt-drivers.md#structured-output) provides a more robust solution for having the LLM generate structured output. + And if an LLM does not natively support structured output, a `JsonSchemaRule` will automatically be added. + [JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema. This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types. diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 7a8c1b470..eefee0ff2 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -55,6 +55,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -133,7 +134,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "toolChoice": self.tool_choice, } - if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.structured_output_strategy == "tool" + ): self._add_structured_output_tool(prompt_stack) params["toolConfig"]["toolChoice"] = {"any": {}} diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 48e8ac18b..99053713a 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -68,6 +68,7 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -135,7 +136,11 @@ def _base_params(self, prompt_stack: 
PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_choice"] = self.tool_choice - if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.structured_output_strategy == "tool" + ): self._add_structured_output_tool(prompt_stack) params["tool_choice"] = {"type": "any"} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index d13a045c3..950c80cf8 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -56,6 +56,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index c7438aa99..a7121b440 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -111,7 +112,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None: + if prompt_stack.output_schema is not None and self.use_native_structured_output: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_object", diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index ff486167b..29c43a91e 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -63,6 +63,7 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -163,7 +164,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools: params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} - if prompt_stack.output_schema is not None and self.structured_output_strategy == 
"tool": + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.structured_output_strategy == "tool" + ): params["tool_config"]["function_calling_config"]["mode"] = "auto" self._add_structured_output_tool(prompt_stack) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 62f463a1b..5b24f083b 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -35,6 +35,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: Literal["native", "tool"] = field( default="native", kw_only=True, metadata={"serializable": True} ) @@ -120,7 +121,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema and self.structured_output_strategy == "native": + if ( + prompt_stack.output_schema + and self.use_native_structured_output + and self.structured_output_strategy == "native" + ): # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.output_schema.json_schema("Output Schema") # Grammar does not support $schema and $id diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 734a73308..295d926d1 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,6 +68,7 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -109,7 +110,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None: + if prompt_stack.output_schema is not None and self.use_native_structured_output: if self.structured_output_strategy == "native": params["format"] = prompt_stack.output_schema.json_schema("Output") elif self.structured_output_strategy == "tool": diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 56b1b3405..69e615585 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -76,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] 
= field( default=Factory( @@ -158,7 +159,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["tool_choice"] = self.tool_choice params["parallel_tool_calls"] = self.parallel_tool_calls - if prompt_stack.output_schema is not None: + if prompt_stack.output_schema is not None and self.use_native_structured_output: if self.structured_output_strategy == "native": params["response_format"] = { "type": "json_schema", diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 276c2c229..cd00ec574 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -190,7 +190,10 @@ def try_run(self) -> BaseArtifact: else: output = result.to_artifact() - if self.output_schema is not None and self.prompt_driver.structured_output_strategy == "native": + if ( + self.prompt_driver.use_native_structured_output + and self.prompt_driver.structured_output_strategy == "native" + ): return JsonArtifact(output.value) else: return output @@ -210,6 +213,8 @@ def preprocess(self, structure: Structure) -> BaseTask: return self def default_generate_system_template(self, _: PromptTask) -> str: + from griptape.rules import JsonSchemaRule + schema = self.actions_schema().json_schema("Actions Schema") schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. @@ -219,6 +224,10 @@ def default_generate_system_template(self, _: PromptTask) -> str: actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), use_native_tools=self.prompt_driver.use_native_tools, + use_native_structured_output=self.prompt_driver.use_native_structured_output, + json_schema_rule=JsonSchemaRule(self.output_schema.json_schema("Output Schema")) + if self.output_schema is not None + else None, stop_sequence=self.response_stop_sequence, ) diff --git a/griptape/templates/tasks/prompt_task/system.j2 b/griptape/templates/tasks/prompt_task/system.j2 index b262e7c72..4dcd34ee5 100644 --- a/griptape/templates/tasks/prompt_task/system.j2 +++ b/griptape/templates/tasks/prompt_task/system.j2 @@ -26,3 +26,7 @@ NEVER make up actions, action names, or action paths. NEVER make up facts. 
NEVER {{ rulesets }} {% endif %} +{% if not use_native_structured_output and json_schema_rule %} + +{{ json_schema_rule }} +{% endif %} diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 782c8ecd4..01824af06 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -36,7 +36,7 @@ class MockPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if prompt_stack.output_schema is not None: + if self.use_native_structured_output and prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": return Message( content=[TextMessageContent(TextArtifact(json.dumps(self.mock_structured_output)))], @@ -84,7 +84,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - if prompt_stack.output_schema is not None: + if self.use_native_structured_output and prompt_stack.output_schema is not None: if self.structured_output_strategy == "native": yield DeltaMessage( content=TextDeltaMessageContent(json.dumps(self.mock_structured_output)), diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index b2fd51d24..77c2631f3 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,6 +51,7 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, @@ -107,6 +108,7 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index fa13480c1..f412e10cb 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -26,6 +26,7 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, "structured_output_strategy": "tool", + "use_native_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index a30cea001..45fbfd6ab 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -37,6 +37,7 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 94e258e36..0c2e665a6 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ 
b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,6 +26,7 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, + "use_native_structured_output": True, "structured_output_strategy": "native", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index 15646cc1d..f425913b5 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,6 +18,7 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, + "use_native_structured_output": False, "structured_output_strategy": "native", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index 910ae3240..3c8ef0e0e 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,6 +25,7 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, + "use_native_structured_output": True, "structured_output_strategy": "tool", "extra_params": {}, }, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index 344d14d99..bc9b02cd3 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -29,6 +29,7 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, "structured_output_strategy": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index b31776f63..81c642814 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -384,11 +384,13 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -412,13 +414,11 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if driver.structured_output_strategy == "tool" + if use_native_structured_output and driver.structured_output_strategy == "tool" else [] ), ], - "toolChoice": {"any": {}} - if driver.structured_output_strategy == "tool" - else driver.tool_choice, + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, } } if use_native_tools @@ -437,12 +437,16 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + 
def test_try_stream_run( + self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AmazonBedrockPromptDriver( model="ai21.j2", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -467,13 +471,11 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ *self.BEDROCK_TOOLS, *( [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] - if driver.structured_output_strategy == "tool" + if use_native_structured_output and driver.structured_output_strategy == "tool" else [] ), ], - "toolChoice": {"any": {}} - if driver.structured_output_strategy == "tool" - else driver.tool_choice, + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, } } if use_native_tools diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 147c69103..687db3b68 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -370,12 +370,14 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -395,11 +397,15 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): **{ "tools": [ *self.ANTHROPIC_TOOLS, - *([self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.structured_output_strategy == "tool" + else [] + ), ] if use_native_tools else {}, - "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, } if use_native_tools else {}, @@ -416,13 +422,17 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream_run( + self, mock_stream_client, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -444,11 +454,15 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na **{ "tools": [ *self.ANTHROPIC_TOOLS, - *([self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.structured_output_strategy == "tool" + else [] + ), ] if 
use_native_tools else {}, - "tool_choice": {"type": "any"} if driver.structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, } if use_native_tools else {}, diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 3c8d39475..f7f153dd0 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -67,6 +67,7 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_run( self, @@ -74,6 +75,7 @@ def test_try_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -82,6 +84,7 @@ def test_try_run( azure_deployment="deployment-id", model="gpt-4", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -98,9 +101,15 @@ def test_try_run( **{ "tools": [ *self.OPENAI_TOOLS, - *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ], - "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": "required" + if use_native_structured_output and structured_output_strategy == "tool" + else driver.tool_choice, } if use_native_tools else {}, @@ -114,7 +123,7 @@ def test_try_run( }, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -127,6 +136,7 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool"]) def test_try_stream_run( self, @@ -134,6 +144,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -143,6 +154,7 @@ def test_try_stream_run( model="gpt-4", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -161,9 +173,15 @@ def test_try_stream_run( **{ "tools": [ *self.OPENAI_TOOLS, - *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ], - "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": "required" + if use_native_structured_output and structured_output_strategy == "tool" + else driver.tool_choice, } if use_native_tools else {}, @@ -177,7 +195,7 @@ def test_try_stream_run( }, } } - if structured_output_strategy == "native" + if use_native_structured_output and 
structured_output_strategy == "native" else {}, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 17e9251d3..ad417cac5 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -338,6 +338,7 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -345,6 +346,7 @@ def test_try_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -352,6 +354,7 @@ def test_try_run( model="command", api_key="api-key", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -367,7 +370,11 @@ def test_try_run( **{ "tools": [ *self.COHERE_TOOLS, - *([self.COHERE_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ] } if use_native_tools @@ -378,7 +385,7 @@ def test_try_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, @@ -399,6 +406,7 @@ def test_try_run( assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -406,6 +414,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -414,6 +423,7 @@ def test_try_stream_run( api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -430,7 +440,11 @@ def test_try_stream_run( **{ "tools": [ *self.COHERE_TOOLS, - *([self.COHERE_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ] } if use_native_tools @@ -441,7 +455,7 @@ def test_try_stream_run( "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, stop_sequences=[], temperature=0.1, diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index cc17de3c1..a0b68a6af 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -177,7 +177,10 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", 
[True, False]) + def test_try_run( + self, mock_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -185,6 +188,7 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy="tool", extra_params={"max_output_tokens": 10}, ) @@ -209,11 +213,11 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native tool_declarations = call_args.kwargs["tools"] tools = [ *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), ] assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if driver.structured_output_strategy == "tool": + if use_native_structured_output: assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(message.value[0], TextArtifact) @@ -227,7 +231,10 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream( + self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -236,6 +243,7 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"max_output_tokens": 10}, ) @@ -261,11 +269,11 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, tool_declarations = call_args.kwargs["tools"] tools = [ *self.GOOGLE_TOOLS, - *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if driver.structured_output_strategy == "tool" else []), + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), ] assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools - if driver.structured_output_strategy == "tool": + if use_native_structured_output: assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index a65befbce..334c1649e 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -54,11 +54,13 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - def test_try_run(self, prompt_stack, mock_client): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", + 
use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -71,18 +73,22 @@ def test_try_run(self, prompt_stack, mock_client): return_full_text=False, max_new_tokens=250, foo="bar", - grammar={"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_client_stream): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structured_output): # Given driver = HuggingFaceHubPromptDriver( api_token="api-token", model="repo-id", stream=True, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -96,7 +102,9 @@ def test_try_stream(self, prompt_stack, mock_client_stream): return_full_text=False, max_new_tokens=250, foo="bar", - grammar={"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}, + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, stream=True, ) assert isinstance(event.content, TextDeltaMessageContent) diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 46c3ef4af..cffcd3954 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -232,6 +232,7 @@ def test_init(self): assert OllamaPromptDriver(model="llama") @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -239,12 +240,14 @@ def test_try_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given driver = OllamaPromptDriver( model="llama", use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -264,12 +267,18 @@ def test_try_run( **{ "tools": [ *self.OLLAMA_TOOLS, - *([self.OLLAMA_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ] } if use_native_tools else {}, - **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and structured_output_strategy == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -281,6 +290,7 @@ def test_try_run( assert message.value[1].value.input == {"foo": "bar"} @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -288,6 +298,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -295,6 +306,7 @@ def test_try_stream_run( model="llama", stream=True, 
use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -307,7 +319,9 @@ def test_try_stream_run( messages=messages, model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, - **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} if structured_output_strategy == "native" else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and structured_output_strategy == "native" + else {}, stream=True, foo="bar", ) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 44c3ecba4..ed6085538 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -371,6 +371,7 @@ def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_run( self, @@ -378,12 +379,14 @@ def test_try_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -401,9 +404,15 @@ def test_try_run( **{ "tools": [ *self.OPENAI_TOOLS, - *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + [self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ], - "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": "required" + if use_native_structured_output and structured_output_strategy == "tool" + else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools @@ -418,7 +427,7 @@ def test_try_run( }, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -500,6 +509,7 @@ def test_try_run_response_format_json_schema(self, mock_chat_completion_create, assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "foo"]) def test_try_stream_run( self, @@ -507,6 +517,7 @@ def test_try_stream_run( prompt_stack, messages, use_native_tools, + use_native_structured_output, structured_output_strategy, ): # Given @@ -514,6 +525,7 @@ def test_try_stream_run( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, structured_output_strategy=structured_output_strategy, extra_params={"foo": "bar"}, ) @@ -534,9 +546,15 @@ def test_try_stream_run( **{ "tools": [ *self.OPENAI_TOOLS, - *([self.OPENAI_STRUCTURED_OUTPUT_TOOL] if structured_output_strategy == "tool" else []), + *( + 
[self.OPENAI_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and structured_output_strategy == "tool" + else [] + ), ], - "tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice, + "tool_choice": "required" + if use_native_structured_output and structured_output_strategy == "tool" + else driver.tool_choice, "parallel_tool_calls": driver.parallel_tool_calls, } if use_native_tools @@ -551,7 +569,7 @@ def test_try_stream_run( }, } } - if structured_output_strategy == "native" + if use_native_structured_output and structured_output_strategy == "native" else {}, foo="bar", ) @@ -578,11 +596,11 @@ def test_try_stream_run( def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given - prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False, + use_native_structured_output=False, ) # When @@ -612,12 +630,12 @@ def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completio assert e.value.args[0] == "Completion with more than one choice is not supported yet." def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): - prompt_stack.output_schema = None driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, use_native_tools=False, + use_native_structured_output=False, ) # When diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 34471fb39..3344644a3 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -83,6 +83,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, "structured_output_strategy": "native", }, } diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index fba790470..e4d3060a5 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -183,6 +183,7 @@ def test_prompt_stack_native_schema(self): task = PromptTask( input="foo", prompt_driver=MockPromptDriver( + use_native_structured_output=True, mock_structured_output={"baz": "foo"}, ), output_schema=output_schema, @@ -204,7 +205,9 @@ def test_prompt_stack_native_schema(self): def test_prompt_stack_empty_native_schema(self): task = PromptTask( input="foo", - prompt_driver=MockPromptDriver(), + prompt_driver=MockPromptDriver( + use_native_structured_output=True, + ), rules=[JsonSchemaRule({"foo": {}})], ) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index ba419480d..f3a18b1e2 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -258,6 +258,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "structured_output_strategy": "native", + "use_native_structured_output": False, "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 3a7476596..082ccc466 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,6 +399,7 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, "structured_output_strategy": "native", }, "tools": [