Skip to content

Commit

Permalink
Clean up logic
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 31, 2024
1 parent b9d0e82 commit 6fe7f0d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 25 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`.
- Validators to `Agent` initialization.


### Changed

- Rulesets can now be serialized and deserialized.
Expand All @@ -38,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Parsing `ActionCallDeltaMessageContent`s with empty string `partial_input`s.
- `Stream` util not properly propagating thread contextvars.
- `ValueError` with `DuckDuckGoWebSearchDriver`.
- `Agent.stream` overriding `Agent.prompt_driver.stream` even when `Agent.prompt_driver` is explicitly provided.

### Deprecated

Expand Down
31 changes: 12 additions & 19 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import Attribute, define, evolve, field
from attrs import Attribute, define, evolve, field, validators

from griptape.artifacts.text_artifact import TextArtifact
from griptape.common import observable
Expand All @@ -23,8 +23,8 @@ class Agent(Structure):
input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field(
default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
)
stream: Optional[bool] = field(default=None, kw_only=True)
prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True)
stream: bool = field(default=None, kw_only=True)
prompt_driver: BasePromptDriver = field(default=None, kw_only=True)
tools: list[BaseTool] = field(factory=list, kw_only=True)
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
fail_fast: bool = field(default=False, kw_only=True)
Expand All @@ -37,15 +37,6 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB
if fail_fast:
raise ValueError("Agents cannot fail fast, as they can only have 1 task.")

@stream.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_stream(self, _: Attribute, stream: bool) -> None: # noqa: FBT001
if stream is not None and self.prompt_driver is not None:
warnings.warn(
"`Agent.stream` is set, but `Agent.prompt_driver` was provided. `Agent.stream` will be ignored. This will be an error in the future.",
UserWarning,
stacklevel=2,
)

@prompt_driver.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_prompt_driver(self, _: Attribute, prompt_driver: Optional[BasePromptDriver]) -> None: # noqa: FBT001
if prompt_driver is not None and self.stream is not None:
Expand Down Expand Up @@ -95,17 +86,19 @@ def try_run(self, *args) -> Agent:
return self

def _init_task(self) -> None:
stream = False if self.stream is None else self.stream
if self.stream is None:
with validators.disabled():
self.stream = Defaults.drivers_config.prompt_driver.stream

if self.prompt_driver is None:
with validators.disabled():
self.prompt_driver = evolve(Defaults.drivers_config.prompt_driver, stream=self.stream)

prompt_driver = (
evolve(Defaults.drivers_config.prompt_driver, stream=stream)
if self.prompt_driver is None
else self.prompt_driver
)
task = PromptTask(
self.input,
prompt_driver=prompt_driver,
prompt_driver=self.prompt_driver,
tools=self.tools,
max_meta_memory_entries=self.max_meta_memory_entries,
)

self.add_task(task)
9 changes: 4 additions & 5 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,6 @@ def test_stream_mutation(self):
assert agent.tasks[0].prompt_driver.stream is False
assert agent.tasks[0].prompt_driver is not prompt_driver

def test_validate_stream(self):
with pytest.warns(UserWarning, match="`Agent.stream` is set, but `Agent.prompt_driver` was provided."):
Agent(stream=True, prompt_driver=MockPromptDriver())

def test_validate_prompt_driver(self):
with pytest.warns(UserWarning, match="`Agent.prompt_driver` is set, but `Agent.stream` was provided."):
Agent(stream=True, prompt_driver=MockPromptDriver())
Expand All @@ -298,17 +294,20 @@ def test_validate_tasks(self):
with pytest.warns(UserWarning, match="`Agent.tasks` is set, but `Agent.prompt_driver` was provided."):
Agent(prompt_driver=MockPromptDriver(), tasks=[PromptTask()])

def test_sugar_fields(self):
def test_field_hierarchy(self):
# Test that stream on its own propagates to the task.
agent = Agent(stream=True)

assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is True

# Test that stream does not propagate to the prompt driver if explicitly provided
agent = Agent(stream=True, prompt_driver=MockPromptDriver())

assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is False

# Test that neither stream nor prompt driver propagate to the task if explicitly provided
agent = Agent(
stream=False,
prompt_driver=MockPromptDriver(stream=False),
Expand Down

0 comments on commit 6fe7f0d

Please sign in to comment.