From 9f8927d7361b6d275b9da71ebf930f85675005c8 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 10:28:22 -0700 Subject: [PATCH] Add `LocalStructureRunDriver.event_listeners` to allow passing Event Listeners to be active for a Structure's run --- CHANGELOG.md | 2 +- .../structure_run/local_structure_run_driver.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 788dd2e236..42ed7f9948 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Parameter `meta: dict` on `BaseEvent`. - `AzureOpenAiTextToSpeechDriver`. -- Ability to use Event Listeners as Context Managers for temporarily setting the Event Bus listeners. +- `LocalStructureRunDriver.event_listeners` for adding Event Listeners to a local Structure run. ### Changed - **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`. diff --git a/griptape/drivers/structure_run/local_structure_run_driver.py b/griptape/drivers/structure_run/local_structure_run_driver.py index c0049b29aa..e2080de93b 100644 --- a/griptape/drivers/structure_run/local_structure_run_driver.py +++ b/griptape/drivers/structure_run/local_structure_run_driver.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +from contextlib import ExitStack from typing import TYPE_CHECKING, Callable from attrs import define, field @@ -9,18 +10,31 @@ from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver if TYPE_CHECKING: + from griptape.events import EventListener from griptape.structures import Structure @define class LocalStructureRunDriver(BaseStructureRunDriver): + """Runs a structure locally. + + Attributes: + structure_factory_fn: A function that returns a Structure. + event_listeners: A list of Event Listeners to add to the Event Bus for the Structure's run. + """ + structure_factory_fn: Callable[[], Structure] = field(kw_only=True) + event_listeners: list[EventListener] = field(factory=list, kw_only=True) def try_run(self, *args: BaseArtifact) -> BaseArtifact: old_env = os.environ.copy() try: os.environ.update(self.env) - structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) + + with ExitStack() as stack: + for event_listener in self.event_listeners: + stack.enter_context(event_listener) + structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) finally: os.environ.clear() os.environ.update(old_env)