Skip to content

Commit

Permalink
Fix Agent unintentionally modifying stream for all Prompt Drivers (
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Dec 13, 2024
1 parent 7073c50 commit d69bd8c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## Fixed

- Exception when calling `Structure.to_json()` after it has run.

### Added

- `PromptTask.conversation_memory` for setting the Conversation Memory on a Prompt Task.
Expand All @@ -19,6 +15,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BranchTask` for selecting which Tasks (if any) to run based on a condition.
- Support for `BranchTask` in `StructureVisualizer`.

### Fixed

- Exception when calling `Structure.to_json()` after it has run.
- `Agent` unintentionally modifying `stream` for all Prompt Drivers.

## [1.0.0] - 2024-12-09

### Added
Expand Down
7 changes: 5 additions & 2 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import Attribute, Factory, define, field
from attrs import Attribute, Factory, define, evolve, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.common import observable
Expand All @@ -24,7 +24,10 @@ class Agent(Structure):
)
stream: bool = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver.stream), kw_only=True)
prompt_driver: BasePromptDriver = field(
default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
default=Factory(
lambda self: evolve(Defaults.drivers_config.prompt_driver, stream=self.stream), takes_self=True
),
kw_only=True,
)
tools: list[BaseTool] = field(factory=list, kw_only=True)
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,11 @@ def test_is_running(self):
task.state = BaseTask.State.RUNNING

assert agent.is_running()

def test_stream_mutation(self):
prompt_driver = MockPromptDriver()
agent = Agent(prompt_driver=MockPromptDriver(), stream=True)

assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is True
assert agent.tasks[0].prompt_driver is not prompt_driver

0 comments on commit d69bd8c

Please sign in to comment.