diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d81a27aa..fe7b7707b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased -## Fixed - -- Exception when calling `Structure.to_json()` after it has run. - ### Added - `PromptTask.conversation_memory` for setting the Conversation Memory on a Prompt Task. @@ -19,6 +15,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BranchTask` for selecting which Tasks (if any) to run based on a condition. - Support for `BranchTask` in `StructureVisualizer`. +### Fixed + +- Exception when calling `Structure.to_json()` after it has run. +- `Agent` unintentionally modifying `stream` for all Prompt Drivers. + ## [1.0.0] - 2024-12-09 ### Added diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index be1b73e34..128c02faa 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union -from attrs import Attribute, Factory, define, field +from attrs import Attribute, Factory, define, evolve, field from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable @@ -24,7 +24,10 @@ class Agent(Structure): ) stream: bool = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver.stream), kw_only=True) prompt_driver: BasePromptDriver = field( - default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True + default=Factory( + lambda self: evolve(Defaults.drivers_config.prompt_driver, stream=self.stream), takes_self=True + ), + kw_only=True, ) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 27211f29e..387910f40 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -276,3 +276,11 @@ def test_is_running(self): task.state = BaseTask.State.RUNNING assert agent.is_running() + + def test_stream_mutation(self): + prompt_driver = MockPromptDriver() + agent = Agent(prompt_driver=MockPromptDriver(), stream=True) + + assert isinstance(agent.tasks[0], PromptTask) + assert agent.tasks[0].prompt_driver.stream is True + assert agent.tasks[0].prompt_driver is not prompt_driver