diff --git a/CHANGELOG.md b/CHANGELOG.md index c975aadcb..659c8d492 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. - Validators to `Agent` initialization. - ### Changed - Rulesets can now be serialized and deserialized. @@ -38,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parsing `ActionCallDeltaMessageContent`s with empty string `partial_input`s. - `Stream` util not properly propagating thread contextvars. - `ValueError` with `DuckDuckGoWebSearchDriver`. +- `Agent.stream` overriding `Agent.prompt_driver.stream` even when `Agent.prompt_driver` is explicitly provided. ### Deprecated diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 011e024a5..baf36108f 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -3,7 +3,7 @@ import warnings from typing import TYPE_CHECKING, Callable, Optional, Union -from attrs import Attribute, define, evolve, field +from attrs import Attribute, define, evolve, field, validators from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable @@ -23,8 +23,8 @@ class Agent(Structure): input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) - stream: Optional[bool] = field(default=None, kw_only=True) - prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True) + stream: bool = field(default=None, kw_only=True) + prompt_driver: BasePromptDriver = field(default=None, kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) @@ -37,15 +37,6 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB if fail_fast: raise ValueError("Agents cannot fail fast, as they can only have 1 task.") - @stream.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def validate_stream(self, _: Attribute, stream: bool) -> None: # noqa: FBT001 - if stream is not None and self.prompt_driver is not None: - warnings.warn( - "`Agent.stream` is set, but `Agent.prompt_driver` was provided. `Agent.stream` will be ignored. This will be an error in the future.", - UserWarning, - stacklevel=2, - ) - @prompt_driver.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] def validate_prompt_driver(self, _: Attribute, prompt_driver: Optional[BasePromptDriver]) -> None: # noqa: FBT001 if prompt_driver is not None and self.stream is not None: @@ -95,17 +86,19 @@ def try_run(self, *args) -> Agent: return self def _init_task(self) -> None: - stream = False if self.stream is None else self.stream + if self.stream is None: + with validators.disabled(): + self.stream = Defaults.drivers_config.prompt_driver.stream + + if self.prompt_driver is None: + with validators.disabled(): + self.prompt_driver = evolve(Defaults.drivers_config.prompt_driver, stream=self.stream) - prompt_driver = ( - evolve(Defaults.drivers_config.prompt_driver, stream=stream) - if self.prompt_driver is None - else self.prompt_driver - ) task = PromptTask( self.input, - prompt_driver=prompt_driver, + prompt_driver=self.prompt_driver, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries, ) + self.add_task(task) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index d1cd618ef..809d174b5 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -286,10 +286,6 @@ def test_stream_mutation(self): assert agent.tasks[0].prompt_driver.stream is False assert agent.tasks[0].prompt_driver is not prompt_driver - def test_validate_stream(self): - with pytest.warns(UserWarning, match="`Agent.stream` is set, but `Agent.prompt_driver` was provided."): - Agent(stream=True, prompt_driver=MockPromptDriver()) - def test_validate_prompt_driver(self): with pytest.warns(UserWarning, match="`Agent.prompt_driver` is set, but `Agent.stream` was provided."): Agent(stream=True, prompt_driver=MockPromptDriver()) @@ -298,17 +294,20 @@ def test_validate_tasks(self): with pytest.warns(UserWarning, match="`Agent.tasks` is set, but `Agent.prompt_driver` was provided."): Agent(prompt_driver=MockPromptDriver(), tasks=[PromptTask()]) - def test_sugar_fields(self): + def test_field_hierarchy(self): + # Test that stream on its own propagates to the task. agent = Agent(stream=True) assert isinstance(agent.tasks[0], PromptTask) assert agent.tasks[0].prompt_driver.stream is True + # Test that stream does not propagate to the prompt driver if explicitly provided agent = Agent(stream=True, prompt_driver=MockPromptDriver()) assert isinstance(agent.tasks[0], PromptTask) assert agent.tasks[0].prompt_driver.stream is False + # Test that neither stream nor prompt driver propagate to the task if explicitly provided agent = Agent( stream=False, prompt_driver=MockPromptDriver(stream=False),