Skip to content

Commit

Permalink
Add LocalStructureRunDriver.event_listeners to allow passing Event …
Browse files Browse the repository at this point in the history
…Listeners to be active for a Structure's run
  • Loading branch information
collindutter committed Sep 11, 2024
1 parent 0b5a9fb commit 0bbccca
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Parameter `meta: dict` on `BaseEvent`.
- `AzureOpenAiTextToSpeechDriver`.
- Ability to use Event Listeners as Context Managers for temporarily setting the Event Bus listeners.
- `LocalStructureRunDriver.event_listeners` for adding Event Listeners to a local Structure run.

### Changed
- **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`.
Expand Down
16 changes: 15 additions & 1 deletion griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from contextlib import ExitStack
from typing import TYPE_CHECKING, Callable

from attrs import define, field
Expand All @@ -9,18 +10,31 @@
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver

if TYPE_CHECKING:
from griptape.events import EventListener
from griptape.structures import Structure


@define
class LocalStructureRunDriver(BaseStructureRunDriver):
"""Runs a structure locally.
Attributes:
structure_factory_fn: A function that returns a Structure.
event_listeners: A list of Event Listeners to add to the Event Bus for the Structure's run.
"""

structure_factory_fn: Callable[[], Structure] = field(kw_only=True)
event_listeners: list[EventListener] = field(factory=list, kw_only=True)

def try_run(self, *args: BaseArtifact) -> BaseArtifact:
old_env = os.environ.copy()
try:
os.environ.update(self.env)
structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args])

with ExitStack() as stack:
for event_listener in self.event_listeners:
stack.enter_context(event_listener)

Check warning on line 36 in griptape/drivers/structure_run/local_structure_run_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/structure_run/local_structure_run_driver.py#L36

Added line #L36 was not covered by tests
structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args])
finally:
os.environ.clear()
os.environ.update(old_env)
Expand Down
5 changes: 2 additions & 3 deletions griptape/events/event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ def add_event_listeners(self, event_listeners: list[EventListener]) -> list[Even
return [self.add_event_listener(event_listener) for event_listener in event_listeners]

def remove_event_listeners(self, event_listeners: list[EventListener]) -> None:
with self._thread_lock:
for event_listener in event_listeners:
self.remove_event_listener(event_listener)
for event_listener in event_listeners:
self.remove_event_listener(event_listener)

def add_event_listener(self, event_listener: EventListener) -> EventListener:
with self._thread_lock:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/events/test_event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def test_add_remove_event_listener(self, pipeline):
mock1 = Mock()
mock2 = Mock()
# duplicate event listeners will only get added once
event_listener_1 = EventBus.add_event_listener(EventListener(mock1, id="1", event_types=[StartPromptEvent]))
EventBus.add_event_listener(EventListener(mock1, id="1", event_types=[StartPromptEvent]))
event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent]))
EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent]))

event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent]))
event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent]))
Expand Down

0 comments on commit 0bbccca

Please sign in to comment.