diff --git a/griptape/drivers/structure_run/local_structure_run_driver.py b/griptape/drivers/structure_run/local_structure_run_driver.py index e2080de93b..543bf785cc 100644 --- a/griptape/drivers/structure_run/local_structure_run_driver.py +++ b/griptape/drivers/structure_run/local_structure_run_driver.py @@ -8,6 +8,7 @@ from griptape.artifacts import BaseArtifact, InfoArtifact from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver +from griptape.events import EventBus if TYPE_CHECKING: from griptape.events import EventListener @@ -28,6 +29,7 @@ class LocalStructureRunDriver(BaseStructureRunDriver): def try_run(self, *args: BaseArtifact) -> BaseArtifact: old_env = os.environ.copy() + old_event_listeners = EventBus.event_listeners.copy() try: os.environ.update(self.env) @@ -36,6 +38,7 @@ def try_run(self, *args: BaseArtifact) -> BaseArtifact: stack.enter_context(event_listener) structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) finally: + EventBus.set_event_listeners(old_event_listeners) os.environ.clear() os.environ.update(old_env) diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index b7954480e1..5491678429 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -43,6 +43,10 @@ def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self._event_listeners: event_listener.publish_event(event, flush=flush) + def set_event_listeners(self, event_listeners: list[EventListener]) -> None: + with self._thread_lock: + self._event_listeners = event_listeners + def clear_event_listeners(self) -> None: with self._thread_lock: self._event_listeners.clear() diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index 1fad4a1dee..112987828b 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -16,8 +16,6 @@ class EventListener: event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True) driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) - _last_event_listeners: Optional[list[EventListener]] = field(default=None) - def __enter__(self) -> EventListener: from griptape.events import EventBus @@ -30,8 +28,6 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002 EventBus.remove_event_listener(self) - self._last_event_listeners = None - def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: event_types = self.event_types diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index cc432dafb4..8dbf4a7d6a 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -36,6 +36,11 @@ def test_remove_event_listener(self): def test_remove_unknown_event_listener(self): EventBus.remove_event_listener(EventListener()) + def test_set_event_listeners(self): + listeners = [EventListener(), EventListener()] + EventBus.set_event_listeners(listeners) + assert EventBus.event_listeners == listeners + def test_publish_event(self): # Given mock_handler = Mock()