From 0bbcccabe560885deabdba7eeeddc7ceaf951cd9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 10:28:22 -0700 Subject: [PATCH] Add `LocalStructureRunDriver.event_listeners` to allow passing Event Listeners to be active for a Structure's run --- CHANGELOG.md | 1 + .../structure_run/local_structure_run_driver.py | 16 +++++++++++++++- griptape/events/event_bus.py | 5 ++--- tests/unit/events/test_event_listener.py | 4 ++-- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 788dd2e236..ca43a96d01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `meta: dict` on `BaseEvent`. - `AzureOpenAiTextToSpeechDriver`. - Ability to use Event Listeners as Context Managers for temporarily setting the Event Bus listeners. +- `LocalStructureRunDriver.event_listeners` for adding Event Listeners to a local Structure run. ### Changed - **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`. diff --git a/griptape/drivers/structure_run/local_structure_run_driver.py b/griptape/drivers/structure_run/local_structure_run_driver.py index c0049b29aa..e2080de93b 100644 --- a/griptape/drivers/structure_run/local_structure_run_driver.py +++ b/griptape/drivers/structure_run/local_structure_run_driver.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +from contextlib import ExitStack from typing import TYPE_CHECKING, Callable from attrs import define, field @@ -9,18 +10,31 @@ from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver if TYPE_CHECKING: + from griptape.events import EventListener from griptape.structures import Structure @define class LocalStructureRunDriver(BaseStructureRunDriver): + """Runs a structure locally. + + Attributes: + structure_factory_fn: A function that returns a Structure. + event_listeners: A list of Event Listeners to add to the Event Bus for the Structure's run. + """ + structure_factory_fn: Callable[[], Structure] = field(kw_only=True) + event_listeners: list[EventListener] = field(factory=list, kw_only=True) def try_run(self, *args: BaseArtifact) -> BaseArtifact: old_env = os.environ.copy() try: os.environ.update(self.env) - structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) + + with ExitStack() as stack: + for event_listener in self.event_listeners: + stack.enter_context(event_listener) + structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) finally: os.environ.clear() os.environ.update(old_env) diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index f6d797473c..b7954480e1 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -24,9 +24,8 @@ def add_event_listeners(self, event_listeners: list[EventListener]) -> list[Even return [self.add_event_listener(event_listener) for event_listener in event_listeners] def remove_event_listeners(self, event_listeners: list[EventListener]) -> None: - with self._thread_lock: - for event_listener in event_listeners: - self.remove_event_listener(event_listener) + for event_listener in event_listeners: + self.remove_event_listener(event_listener) def add_event_listener(self, event_listener: EventListener) -> EventListener: with self._thread_lock: diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 6af1213b16..f35bc5416f 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -94,8 +94,8 @@ def test_add_remove_event_listener(self, pipeline): mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, id="1", event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, id="1", event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent]))