Skip to content

Commit

Permalink
Clear and restore event listeners in driver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 12, 2024
1 parent 54e680b commit 8b5a875
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Parameter `meta: dict` on `BaseEvent`.
- `AzureOpenAiTextToSpeechDriver`.
- Ability to use Event Listeners as Context Managers for temporarily setting the Event Bus listeners.
- `JsonSchemaRule` for instructing the LLM to output a JSON object that conforms to a schema.
- Ability to use Drivers Configs as Context Managers for temporarily setting the default Drivers.
- `JsonSchemaRule` for instructing the LLM to output a JSON object that conforms to a schema.
- `LocalStructureRunDriver.event_listeners` for adding Event Listeners to a local Structure run.

### Changed
Expand Down
4 changes: 4 additions & 0 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from griptape.artifacts import BaseArtifact, InfoArtifact
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver
from griptape.events import EventBus

if TYPE_CHECKING:
from griptape.events import EventListener
Expand All @@ -28,6 +29,8 @@ class LocalStructureRunDriver(BaseStructureRunDriver):

def try_run(self, *args: BaseArtifact) -> BaseArtifact:
old_env = os.environ.copy()
old_event_listeners = EventBus.event_listeners.copy()
EventBus.clear_event_listeners()
try:
os.environ.update(self.env)

Expand All @@ -36,6 +39,7 @@ def try_run(self, *args: BaseArtifact) -> BaseArtifact:
stack.enter_context(event_listener)
structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args])
finally:
EventBus.set_event_listeners(old_event_listeners)
os.environ.clear()
os.environ.update(old_env)

Expand Down
4 changes: 4 additions & 0 deletions griptape/events/event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
for event_listener in self._event_listeners:
event_listener.publish_event(event, flush=flush)

def set_event_listeners(self, event_listeners: list[EventListener]) -> None:
with self._thread_lock:
self._event_listeners = event_listeners

def clear_event_listeners(self) -> None:
with self._thread_lock:
self._event_listeners.clear()
Expand Down
4 changes: 0 additions & 4 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class EventListener:
event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True)
driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True)

_last_event_listeners: Optional[list[EventListener]] = field(default=None)

def __enter__(self) -> EventListener:
from griptape.events import EventBus

Expand All @@ -30,8 +28,6 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002

EventBus.remove_event_listener(self)

self._last_event_listeners = None

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
event_types = self.event_types

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from unittest.mock import Mock

from griptape.drivers import LocalStructureRunDriver
from griptape.events import EventBus, EventListener
from griptape.structures import Agent, Pipeline
from griptape.tasks import StructureRunTask
from tests.mocks.mock_prompt_driver import MockPromptDriver
Expand Down Expand Up @@ -28,3 +30,16 @@ def test_run_with_env(self, mock_config):
pipeline.add_task(task)

assert task.run().to_text() == "value"

def test_run_with_event_listeners(self):
event_listeners = [EventListener(), EventListener()]
EventBus.add_event_listeners(event_listeners)
mock_handler = Mock()
driver = LocalStructureRunDriver(
structure_factory_fn=lambda: Agent(), event_listeners=[EventListener(handler=mock_handler)]
)

driver.run()

assert EventBus.event_listeners == event_listeners
mock_handler.assert_called()
5 changes: 5 additions & 0 deletions tests/unit/events/test_event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def test_remove_event_listener(self):
def test_remove_unknown_event_listener(self):
EventBus.remove_event_listener(EventListener())

def test_set_event_listeners(self):
listeners = [EventListener(), EventListener()]
EventBus.set_event_listeners(listeners)
assert EventBus.event_listeners == listeners

def test_publish_event(self):
# Given
mock_handler = Mock()
Expand Down

0 comments on commit 8b5a875

Please sign in to comment.