Skip to content

Commit

Permalink
Support using multiple/concurrent EventListeners
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 11, 2024
1 parent 82c9b15 commit 38ad801
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 20 deletions.
22 changes: 14 additions & 8 deletions griptape/events/event_bus.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import threading
from typing import TYPE_CHECKING

from attrs import define, field
from attrs import Factory, define, field

from griptape.mixins.singleton_mixin import SingletonMixin

Expand All @@ -13,6 +14,7 @@
@define
class _EventBus(SingletonMixin):
_event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners")
_thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()), alias="_thread_lock")

@property
def event_listeners(self) -> list[EventListener]:
Expand All @@ -22,23 +24,27 @@ def add_event_listeners(self, event_listeners: list[EventListener]) -> list[Even
return [self.add_event_listener(event_listener) for event_listener in event_listeners]

def remove_event_listeners(self, event_listeners: list[EventListener]) -> None:
for event_listener in event_listeners:
self.remove_event_listener(event_listener)
with self._thread_lock:
for event_listener in event_listeners:
self.remove_event_listener(event_listener)

def add_event_listener(self, event_listener: EventListener) -> EventListener:
if event_listener not in self._event_listeners:
self._event_listeners.append(event_listener)
with self._thread_lock:
if event_listener not in self._event_listeners:
self._event_listeners.append(event_listener)

return event_listener

def set_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]:
self._event_listeners = event_listeners
with self._thread_lock:
self._event_listeners = event_listeners

return self._event_listeners

def remove_event_listener(self, event_listener: EventListener) -> None:
if event_listener in self._event_listeners:
self._event_listeners.remove(event_listener)
with self._thread_lock:
if event_listener in self._event_listeners:
self._event_listeners.remove(event_listener)

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
for event_listener in self._event_listeners:
Expand Down
9 changes: 4 additions & 5 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import uuid
from typing import TYPE_CHECKING, Callable, Optional

from attrs import Factory, define, field
Expand All @@ -12,6 +13,7 @@

@define
class EventListener:
id: str = field(default=Factory(lambda: uuid.uuid4().hex), metadata={"serializable": True}, kw_only=True)
handler: Callable[[BaseEvent], Optional[dict]] = field(default=Factory(lambda: lambda event: event.to_dict()))
event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True)
driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True)
Expand All @@ -21,17 +23,14 @@ class EventListener:
def __enter__(self) -> EventListener:
from griptape.events import EventBus

self._last_event_listeners = [*EventBus.event_listeners]

EventBus.set_event_listeners([self])
EventBus.add_event_listener(self)

return self

def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002
from griptape.events import EventBus

if self._last_event_listeners is not None:
EventBus.set_event_listeners(self._last_event_listeners)
EventBus.remove_event_listener(self)

self._last_event_listeners = None

Expand Down
23 changes: 16 additions & 7 deletions tests/unit/events/test_event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def test_add_remove_event_listener(self, pipeline):
mock1 = Mock()
mock2 = Mock()
# duplicate event listeners will only get added once
event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent]))
EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent]))
event_listener_1 = EventBus.add_event_listener(EventListener(mock1, id="1", event_types=[StartPromptEvent]))
EventBus.add_event_listener(EventListener(mock1, id="1", event_types=[StartPromptEvent]))

event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent]))
event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent]))
Expand Down Expand Up @@ -137,10 +137,19 @@ def event_handler(event: BaseEvent):
mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False)

def test_context_manager(self):
EventBus.add_event_listeners([EventListener()])
last_event_listeners = EventBus.event_listeners
e1 = EventListener()
EventBus.add_event_listeners([e1])

with EventListener() as e:
assert EventBus.event_listeners == [e]
with EventListener() as e2:
assert EventBus.event_listeners == [e1, e2]

assert EventBus.event_listeners == last_event_listeners
assert EventBus.event_listeners == [e1]

def test_context_manager_multiple(self):
e1 = EventListener()
EventBus.add_event_listener(e1)

with EventListener() as e2, EventListener() as e3:
assert EventBus.event_listeners == [e1, e2, e3]

assert EventBus.event_listeners == [e1]

0 comments on commit 38ad801

Please sign in to comment.