Skip to content

Commit

Permalink
Clean up Agent initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 30, 2024
1 parent 0cac633 commit 8cfaf18
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 39 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `EvalEngine` for evaluating the performance of an LLM's output against a given input.
- `BaseFileLoader.save()` method for saving an Artifact to a destination.
- `Structure.run_stream()` for streaming Events from a Structure as an iterator.
- Validators to `Agent` initialization.


### Changed

Expand Down
67 changes: 51 additions & 16 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import Attribute, Factory, define, evolve, field
from attrs import Attribute, define, evolve, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.common import observable
Expand All @@ -22,34 +23,52 @@ class Agent(Structure):
input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field(
default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
)
stream: bool = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver.stream), kw_only=True)
prompt_driver: BasePromptDriver = field(
default=Factory(
lambda self: evolve(Defaults.drivers_config.prompt_driver, stream=self.stream), takes_self=True
),
kw_only=True,
)
stream: Optional[bool] = field(default=None, kw_only=True)
prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True)
tools: list[BaseTool] = field(factory=list, kw_only=True)
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
fail_fast: bool = field(default=False, kw_only=True)
_tasks: list[Union[BaseTask, list[BaseTask]]] = field(
factory=list, kw_only=True, alias="tasks", metadata={"serializable": True}
)

@fail_fast.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FBT001
if fail_fast:
raise ValueError("Agents cannot fail fast, as they can only have 1 task.")

@stream.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_stream(self, _: Attribute, stream: bool) -> None: # noqa: FBT001
if stream is not None and self.prompt_driver is not None:
warnings.warn(
"`Agent.stream` is set, but `Agent.prompt_driver` was provided. `Agent.stream` will be ignored. This will be an error in the future.",
UserWarning,
stacklevel=2,
)

@prompt_driver.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_prompt_driver(self, _: Attribute, prompt_driver: Optional[BasePromptDriver]) -> None: # noqa: FBT001
if prompt_driver is not None and self.stream is not None:
warnings.warn(
"`Agent.prompt_driver` is set, but `Agent.stream` was provided. `Agent.stream` will be ignored. This will be an error in the future.",
UserWarning,
stacklevel=2,
)

@_tasks.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_tasks(self, _: Attribute, tasks: list) -> None:
if tasks and self.prompt_driver is not None:
warnings.warn(
"`Agent.tasks` is set, but `Agent.prompt_driver` was provided. `Agent.prompt_driver` will be ignored. This will be an error in the future.",
UserWarning,
stacklevel=2,
)

def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()

self.prompt_driver.stream = self.stream
if len(self.tasks) == 0:
task = PromptTask(
self.input,
prompt_driver=self.prompt_driver,
tools=self.tools,
max_meta_memory_entries=self.max_meta_memory_entries,
)
self.add_task(task)
self._init_task()

@property
def task(self) -> BaseTask:
Expand All @@ -74,3 +93,19 @@ def try_run(self, *args) -> Agent:
self.task.run()

return self

def _init_task(self) -> None:
stream = False if self.stream is None else self.stream

prompt_driver = (
evolve(Defaults.drivers_config.prompt_driver, stream=stream)
if self.prompt_driver is None
else self.prompt_driver
)
task = PromptTask(
self.input,
prompt_driver=prompt_driver,
tools=self.tools,
max_meta_memory_entries=self.max_meta_memory_entries,
)
self.add_task(task)
37 changes: 35 additions & 2 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,41 @@ def test_is_running(self):

def test_stream_mutation(self):
prompt_driver = MockPromptDriver()
agent = Agent(prompt_driver=MockPromptDriver(), stream=True)

agent = Agent(prompt_driver=MockPromptDriver(stream=False), stream=True)

assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is True
assert agent.tasks[0].prompt_driver.stream is False
assert agent.tasks[0].prompt_driver is not prompt_driver

def test_validate_stream(self):
with pytest.warns(UserWarning, match="`Agent.stream` is set, but `Agent.prompt_driver` was provided."):
Agent(stream=True, prompt_driver=MockPromptDriver())

def test_validate_prompt_driver(self):
with pytest.warns(UserWarning, match="`Agent.prompt_driver` is set, but `Agent.stream` was provided."):
Agent(stream=True, prompt_driver=MockPromptDriver())

def test_validate_tasks(self):
with pytest.warns(UserWarning, match="`Agent.tasks` is set, but `Agent.prompt_driver` was provided."):
Agent(prompt_driver=MockPromptDriver(), tasks=[PromptTask()])

def test_sugar_fields(self):
agent = Agent(stream=True)

assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is True

agent = Agent(stream=True, prompt_driver=MockPromptDriver())

assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is False

agent = Agent(
stream=False,
prompt_driver=MockPromptDriver(stream=False),
tasks=[PromptTask(prompt_driver=MockPromptDriver(stream=True))],
)

assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is True
50 changes: 29 additions & 21 deletions tests/unit/utils/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,35 @@


class TestStream:
@pytest.fixture(params=[True, False])
def agent(self, request):
driver = MockPromptDriver(
use_native_tools=request.param,
)
return Agent(stream=request.param, tools=[MockTool()], prompt_driver=driver)

def test_init(self, agent):
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("use_native_tools", [True, False])
def test_init(self, stream, use_native_tools):
prompt_driver = MockPromptDriver(use_native_tools=use_native_tools, stream=stream)
agent = Agent(tools=[MockTool()], prompt_driver=prompt_driver)
chat_stream = Stream(agent)
if agent.stream:
assert chat_stream.structure == agent
chat_stream_run = chat_stream.run()
assert isinstance(chat_stream_run, Iterator)
assert next(chat_stream_run).value == "MockTool.mock-tag (test)"
assert next(chat_stream_run).value == json.dumps({"values": {"test": "test-value"}}, indent=2)
next(chat_stream_run)
assert next(chat_stream_run).value == "Answer: mock output"
next(chat_stream_run)
with pytest.raises(StopIteration):

chat_stream_run = chat_stream.run()
assert chat_stream.structure == agent
assert isinstance(chat_stream_run, Iterator)
if prompt_driver.stream:
if use_native_tools:
assert next(chat_stream_run).value == "MockTool.mock-tag (test)"
assert next(chat_stream_run).value == json.dumps({"values": {"test": "test-value"}}, indent=2)
next(chat_stream_run)
assert next(chat_stream_run).value == "Answer: mock output"
next(chat_stream_run)
with pytest.raises(StopIteration):
next(chat_stream_run)
else:
assert next(chat_stream_run).value == "mock output"
else:
assert next(chat_stream.run()).value == "\n"
with pytest.raises(StopIteration):
next(chat_stream.run())
# MockPromptDriver produces some extra events because it simulates CoT when using native tools.
if use_native_tools:
assert next(chat_stream_run).value == "\n"
assert next(chat_stream_run).value == "\n"
with pytest.raises(StopIteration):
next(chat_stream_run)
else:
assert next(chat_stream_run).value == "\n"
with pytest.raises(StopIteration):
next(chat_stream_run)

0 comments on commit 8cfaf18

Please sign in to comment.