Skip to content

Commit

Permalink
Use generics for event listener types
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 15, 2024
1 parent b4070eb commit 12a05ce
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Structures not flushing events when not listening for `FinishStructureRunEvent`.
- `EventListener.event_types` and the argument to `BaseEventListenerDriver.handler` being out of sync.

## [0.33.1] - 2024-10-11
## \[0.33.1\] - 2024-10-11

### Fixed

- Pinned `cohere` at `~5.11.0` to resolve slow dependency resolution.
- Missing `exa-py` from `all` extra.

## [0.33.0] - 2024-10-09
## \[0.33.0\] - 2024-10-09

## Added

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/misc/src/events_no_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def handler_maybe_drop_events(event: FinishStructureRunEvent) -> Optional[BaseEv
EventBus.add_event_listeners(
[
EventListener(
handler_maybe_drop_events, # pyright: ignore[reportArgumentType]
handler_maybe_drop_events,
event_types=[FinishStructureRunEvent],
# By default, GriptapeCloudEventListenerDriver uses the api key provided
# in the GT_CLOUD_API_KEY environment variable.
Expand Down
15 changes: 9 additions & 6 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar

from attrs import define, field

from .base_event import BaseEvent

if TYPE_CHECKING:
from griptape.drivers import BaseEventListenerDriver

from .base_event import BaseEvent

T = TypeVar("T", bound=BaseEvent)


@define
class EventListener:
class EventListener(Generic[T]):
"""An event listener that listens for events and handles them.
Args:
Expand All @@ -23,8 +26,8 @@ class EventListener:
event_listener_driver: The driver that will be used to publish events.
"""

handler: Optional[Callable[[BaseEvent], Optional[BaseEvent | dict]]] = field(default=None)
event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True)
handler: Optional[Callable[[T], Optional[BaseEvent | dict]]] = field(default=None)
event_types: Optional[list[type[T]]] = field(default=None, kw_only=True)
event_listener_driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True)

_last_event_listeners: Optional[list[EventListener]] = field(default=None)
Expand All @@ -43,7 +46,7 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002

self._last_event_listeners = None

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
def publish_event(self, event: T, *, flush: bool = False) -> None:
event_types = self.event_types

if event_types is None or type(event) in event_types:
Expand Down

0 comments on commit 12a05ce

Please sign in to comment.