Skip to content

Commit

Permalink
Support batched events (#777)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored May 14, 2024
1 parent a899f5a commit 307ac2b
Show file tree
Hide file tree
Showing 17 changed files with 126 additions and 18 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## [0.25.1] - 2024-05-09
### Added
- Optional event batching on Event Listener Drivers.
- `id` field to all events.

### Changed
- Default behavior of Event Listener Drivers to batch events.

## [0.25.0] - 2024-05-06

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ class AmazonSqsEventListenerDriver(BaseEventListenerDriver):

def try_publish_event_payload(self, event_payload: dict) -> None:
self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
entries = [
{"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
for event_payload in event_payload_batch
]

self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ class AwsIotCoreEventListenerDriver(BaseEventListenerDriver):

def try_publish_event_payload(self, event_payload: dict) -> None:
self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload))

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch))
32 changes: 25 additions & 7 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,37 @@
@define
class BaseEventListenerDriver(ABC):
futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)
batched: bool = field(default=True, kw_only=True)
batch_size: int = field(default=10, kw_only=True)

def publish_event(self, event: BaseEvent | dict) -> None:
if isinstance(event, dict):
self.futures_executor.submit(self._safe_try_publish_event_payload, event)
else:
self.futures_executor.submit(self._safe_try_publish_event_payload, event.to_dict())
_batch: list[dict] = field(default=Factory(list), kw_only=True)

@property
def batch(self) -> list[dict]:
return self._batch

def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None:
self.futures_executor.submit(self._safe_try_publish_event, event, flush)

@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None:
...

def _safe_try_publish_event_payload(self, event_payload: dict) -> None:
@abstractmethod
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
...

def _safe_try_publish_event(self, event: BaseEvent | dict, flush: bool) -> None:
try:
self.try_publish_event_payload(event_payload)
event_payload = event if isinstance(event, dict) else event.to_dict()

if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size or flush:
self.try_publish_event_payload_batch(self.batch)
self._batch = []
return
else:
self.try_publish_event_payload(event_payload)
except Exception as e:
logger.error(e)
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,9 @@ def try_publish_event_payload(self, event_payload: dict) -> None:

response = requests.post(url=url, json=event_payload, headers=self.headers)
response.raise_for_status()

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events")

response = requests.post(url=url, json=event_payload_batch, headers=self.headers)
response.raise_for_status()
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ class WebhookEventListenerDriver(BaseEventListenerDriver):
def try_publish_event_payload(self, event_payload: dict) -> None:
response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers)
response.raise_for_status()

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
response = requests.post(url=self.webhook_url, json=event_payload_batch, headers=self.headers)
response.raise_for_status()
6 changes: 5 additions & 1 deletion griptape/events/base_event.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

import time
import uuid
from abc import ABC
from attr import define, field, Factory

from attr import Factory, define, field

from griptape.mixins import SerializableMixin


@define
class BaseEvent(SerializableMixin, ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
timestamp: float = field(default=Factory(lambda: time.time()), kw_only=True, metadata={"serializable": True})
6 changes: 3 additions & 3 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ class EventListener:
event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True)
driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True)

def publish_event(self, event: BaseEvent) -> None:
def publish_event(self, event: BaseEvent, flush: bool = False) -> None:
event_types = self.event_types

if event_types is None or type(event) in event_types:
event_payload = self.handler(event)
if self.driver is not None:
if event_payload is not None and isinstance(event_payload, dict):
self.driver.publish_event(event_payload)
self.driver.publish_event(event_payload, flush=flush)
else:
self.driver.publish_event(event)
self.driver.publish_event(event, flush=flush)
7 changes: 4 additions & 3 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ def remove_event_listener(self, event_listener: EventListener) -> None:
else:
raise ValueError("Event Listener not found.")

def publish_event(self, event: BaseEvent) -> None:
def publish_event(self, event: BaseEvent, flush: bool = False) -> None:
for event_listener in self.event_listeners:
event_listener.publish_event(event)
event_listener.publish_event(event, flush)

def context(self, task: BaseTask) -> dict[str, Any]:
return {"args": self.execution_args, "structure": self}
Expand All @@ -269,7 +269,8 @@ def after_run(self) -> None:
structure_id=self.id,
output_task_input=self.output_task.input,
output_task_output=self.output_task.output,
)
),
flush=True,
)

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/mocks/mock_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

class MockEvent(BaseEvent):
def to_dict(self) -> dict:
return {"timestamp": self.timestamp}
return {"timestamp": self.timestamp, "id": self.id}
5 changes: 4 additions & 1 deletion tests/mocks/mock_event_listener_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@
@define
class MockEventListenerDriver(BaseEventListenerDriver):
def try_publish_event_payload(self, event_payload: dict) -> None:
...
pass

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ def test_init(self, driver):

def test_try_publish_event_payload(self, driver):
driver.try_publish_event_payload(MockEvent().to_dict())

def test_try_publish_event_payload_batch(self, driver):
driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)])
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ def test_init(self, driver):

def test_try_publish_event_payload(self, driver):
driver.try_publish_event_payload(MockEvent().to_dict())

def test_try_publish_event_payload_batch(self, driver):
driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)])
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from tests.mocks.mock_event import MockEvent
from tests.mocks.mock_event_listener_driver import MockEventListenerDriver


class TestBaseEventListenerDriver:
def test__safe_try_publish_event(self):
driver = MockEventListenerDriver(batched=False)

for _ in range(4):
driver._safe_try_publish_event(MockEvent().to_dict(), flush=False)
assert len(driver.batch) == 0

def test__safe_try_publish_event_batch(self):
driver = MockEventListenerDriver(batched=True)

for _ in range(0, 3):
driver._safe_try_publish_event(MockEvent().to_dict(), flush=False)
assert len(driver.batch) == 3

def test__safe_try_publish_event_batch_flush(self):
driver = MockEventListenerDriver(batched=True)

for _ in range(0, 3):
driver._safe_try_publish_event(MockEvent().to_dict(), flush=True)
assert len(driver.batch) == 0
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ def test_try_publish_event_payload(self, mock_post, driver):
headers={"Authorization": "Bearer foo bar"},
)

def try_publish_event_payload_batch(self, mock_post, driver):
for _ in range(3):
event = MockEvent()
driver.try_publish_event_payload(event.to_dict())

mock_post.assert_called_with(
url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events",
json=event.to_dict(),
headers={"Authorization": "Bearer foo bar"},
)

def test_no_structure_run_id(self):
with pytest.raises(ValueError):
GriptapeCloudEventListenerDriver(api_key="foo bar")
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,14 @@ def test_try_publish_event_payload(self, mock_post):
mock_post.assert_called_once_with(
url="foo bar", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"}
)

def test_try_publish_event_payload_batch(self, mock_post):
driver = WebhookEventListenerDriver(webhook_url="foo bar", headers={"Authorization": "Bearer foo bar"})

for _ in range(3):
event = MockEvent()
driver.try_publish_event_payload(event.to_dict())

mock_post.assert_called_with(
url="foo bar", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"}
)
4 changes: 2 additions & 2 deletions tests/unit/events/test_event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def event_handler(_: BaseEvent):
event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent])
event_listener.publish_event(mock_event)

mock_event_listener_driver.publish_event.assert_called_once_with(mock_event)
mock_event_listener_driver.publish_event.assert_called_once_with(mock_event, flush=False)

def test_publish_transformed_event(self):
mock_event_listener_driver = Mock()
Expand All @@ -127,4 +127,4 @@ def event_handler(event: BaseEvent):
event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent])
event_listener.publish_event(mock_event)

mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()})
mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False)

0 comments on commit 307ac2b

Please sign in to comment.