Skip to content

Commit

Permalink
Clear and restore event listeners in driver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 11, 2024
1 parent d7e72ff commit bea4ff6
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
3 changes: 3 additions & 0 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from griptape.artifacts import BaseArtifact, InfoArtifact
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver
from griptape.events import EventBus

if TYPE_CHECKING:
from griptape.events import EventListener
Expand All @@ -28,6 +29,7 @@ class LocalStructureRunDriver(BaseStructureRunDriver):

def try_run(self, *args: BaseArtifact) -> BaseArtifact:
old_env = os.environ.copy()
old_event_listeners = EventBus.event_listeners.copy()
try:
os.environ.update(self.env)

Expand All @@ -36,6 +38,7 @@ def try_run(self, *args: BaseArtifact) -> BaseArtifact:
stack.enter_context(event_listener)

Check warning on line 38 in griptape/drivers/structure_run/local_structure_run_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/structure_run/local_structure_run_driver.py#L38

Added line #L38 was not covered by tests
structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args])
finally:
EventBus.set_event_listeners(old_event_listeners)
os.environ.clear()
os.environ.update(old_env)

Expand Down
4 changes: 4 additions & 0 deletions griptape/events/event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
for event_listener in self._event_listeners:
event_listener.publish_event(event, flush=flush)

def set_event_listeners(self, event_listeners: list[EventListener]) -> None:
with self._thread_lock:
self._event_listeners = event_listeners

def clear_event_listeners(self) -> None:
with self._thread_lock:
self._event_listeners.clear()
Expand Down
4 changes: 0 additions & 4 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class EventListener:
event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True)
driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True)

_last_event_listeners: Optional[list[EventListener]] = field(default=None)

def __enter__(self) -> EventListener:
from griptape.events import EventBus

Expand All @@ -30,8 +28,6 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002

EventBus.remove_event_listener(self)

self._last_event_listeners = None

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
event_types = self.event_types

Expand Down
5 changes: 5 additions & 0 deletions tests/unit/events/test_event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def test_remove_event_listener(self):
def test_remove_unknown_event_listener(self):
EventBus.remove_event_listener(EventListener())

def test_set_event_listeners(self):
listeners = [EventListener(), EventListener()]
EventBus.set_event_listeners(listeners)
assert EventBus.event_listeners == listeners

def test_publish_event(self):
# Given
mock_handler = Mock()
Expand Down

0 comments on commit bea4ff6

Please sign in to comment.