diff --git a/CHANGELOG.md b/CHANGELOG.md index e31647b7b..632c58bee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +## [0.25.1] - 2024-05-09 +### Added +- Optional event batching on Event Listener Drivers. +- `id` field to all events. + +### Changed +- Default behavior of Event Listener Drivers to batch events. + ## [0.25.0] - 2024-05-06 ### Added diff --git a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py index 24e3c9e1e..1c8132b67 100644 --- a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py +++ b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py @@ -20,3 +20,11 @@ class AmazonSqsEventListenerDriver(BaseEventListenerDriver): def try_publish_event_payload(self, event_payload: dict) -> None: self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload)) + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + entries = [ + {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)} + for event_payload in event_payload_batch + ] + + self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries) diff --git a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py index 302fd91d5..c4fd72084 100644 --- a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py +++ b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py @@ -21,3 +21,6 @@ class AwsIotCoreEventListenerDriver(BaseEventListenerDriver): def try_publish_event_payload(self, event_payload: dict) -> None: self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload)) + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch)) diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py index eec0fe320..8e7f827e9 100644 --- a/griptape/drivers/event_listener/base_event_listener_driver.py +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -14,19 +14,37 @@ @define class BaseEventListenerDriver(ABC): futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) + batched: bool = field(default=True, kw_only=True) + batch_size: int = field(default=10, kw_only=True) - def publish_event(self, event: BaseEvent | dict) -> None: - if isinstance(event, dict): - self.futures_executor.submit(self._safe_try_publish_event_payload, event) - else: - self.futures_executor.submit(self._safe_try_publish_event_payload, event.to_dict()) + _batch: list[dict] = field(default=Factory(list), kw_only=True) + + @property + def batch(self) -> list[dict]: + return self._batch + + def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None: + self.futures_executor.submit(self._safe_try_publish_event, event, flush) @abstractmethod def try_publish_event_payload(self, event_payload: dict) -> None: ... - def _safe_try_publish_event_payload(self, event_payload: dict) -> None: + @abstractmethod + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + ... + + def _safe_try_publish_event(self, event: BaseEvent | dict, flush: bool) -> None: try: - self.try_publish_event_payload(event_payload) + event_payload = event if isinstance(event, dict) else event.to_dict() + + if self.batched: + self._batch.append(event_payload) + if len(self.batch) >= self.batch_size or flush: + self.try_publish_event_payload_batch(self.batch) + self._batch = [] + return + else: + self.try_publish_event_payload(event_payload) except Exception as e: logger.error(e) diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py index 461f06be9..2c4149ae7 100644 --- a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -41,3 +41,9 @@ def try_publish_event_payload(self, event_payload: dict) -> None: response = requests.post(url=url, json=event_payload, headers=self.headers) response.raise_for_status() + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events") + + response = requests.post(url=url, json=event_payload_batch, headers=self.headers) + response.raise_for_status() diff --git a/griptape/drivers/event_listener/webhook_event_listener_driver.py b/griptape/drivers/event_listener/webhook_event_listener_driver.py index 3803c86b6..242e5428a 100644 --- a/griptape/drivers/event_listener/webhook_event_listener_driver.py +++ b/griptape/drivers/event_listener/webhook_event_listener_driver.py @@ -15,3 +15,7 @@ class WebhookEventListenerDriver(BaseEventListenerDriver): def try_publish_event_payload(self, event_payload: dict) -> None: response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers) response.raise_for_status() + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + response = requests.post(url=self.webhook_url, json=event_payload_batch, headers=self.headers) + response.raise_for_status() diff --git a/griptape/events/base_event.py b/griptape/events/base_event.py index d32defe96..48a48890e 100644 --- a/griptape/events/base_event.py +++ b/griptape/events/base_event.py @@ -1,11 +1,15 @@ from __future__ import annotations + import time +import uuid from abc import ABC -from attr import define, field, Factory + +from attr import Factory, define, field from griptape.mixins import SerializableMixin @define class BaseEvent(SerializableMixin, ABC): + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) timestamp: float = field(default=Factory(lambda: time.time()), kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index a6b692d4d..44d7b2d85 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -13,13 +13,13 @@ class EventListener: event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True) driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) - def publish_event(self, event: BaseEvent) -> None: + def publish_event(self, event: BaseEvent, flush: bool = False) -> None: event_types = self.event_types if event_types is None or type(event) in event_types: event_payload = self.handler(event) if self.driver is not None: if event_payload is not None and isinstance(event_payload, dict): - self.driver.publish_event(event_payload) + self.driver.publish_event(event_payload, flush=flush) else: - self.driver.publish_event(event) + self.driver.publish_event(event, flush=flush) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index ef9205db9..9cd28ab67 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -249,9 +249,9 @@ def remove_event_listener(self, event_listener: EventListener) -> None: else: raise ValueError("Event Listener not found.") - def publish_event(self, event: BaseEvent) -> None: + def publish_event(self, event: BaseEvent, flush: bool = False) -> None: for event_listener in self.event_listeners: - event_listener.publish_event(event) + event_listener.publish_event(event, flush) def context(self, task: BaseTask) -> dict[str, Any]: return {"args": self.execution_args, "structure": self} @@ -269,7 +269,8 @@ def after_run(self) -> None: structure_id=self.id, output_task_input=self.output_task.input, output_task_output=self.output_task.output, - ) + ), + flush=True, ) @abstractmethod diff --git a/tests/mocks/mock_event.py b/tests/mocks/mock_event.py index 651cf3ece..2b9d9ade3 100644 --- a/tests/mocks/mock_event.py +++ b/tests/mocks/mock_event.py @@ -3,4 +3,4 @@ class MockEvent(BaseEvent): def to_dict(self) -> dict: - return {"timestamp": self.timestamp} + return {"timestamp": self.timestamp, "id": self.id} diff --git a/tests/mocks/mock_event_listener_driver.py b/tests/mocks/mock_event_listener_driver.py index 3e0c173ca..dd54eeb73 100644 --- a/tests/mocks/mock_event_listener_driver.py +++ b/tests/mocks/mock_event_listener_driver.py @@ -6,4 +6,7 @@ @define class MockEventListenerDriver(BaseEventListenerDriver): def try_publish_event_payload(self, event_payload: dict) -> None: - ... + pass + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + pass diff --git a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py index e0a9e7b7c..706831d67 100644 --- a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py @@ -29,3 +29,6 @@ def test_init(self, driver): def test_try_publish_event_payload(self, driver): driver.try_publish_event_payload(MockEvent().to_dict()) + + def test_try_publish_event_payload_batch(self, driver): + driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)]) diff --git a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py index cd50ac82d..9a5fe9ec0 100644 --- a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py @@ -23,3 +23,6 @@ def test_init(self, driver): def test_try_publish_event_payload(self, driver): driver.try_publish_event_payload(MockEvent().to_dict()) + + def test_try_publish_event_payload_batch(self, driver): + driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)]) diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py new file mode 100644 index 000000000..6d33dd2a0 --- /dev/null +++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py @@ -0,0 +1,25 @@ +from tests.mocks.mock_event import MockEvent +from tests.mocks.mock_event_listener_driver import MockEventListenerDriver + + +class TestBaseEventListenerDriver: + def test__safe_try_publish_event(self): + driver = MockEventListenerDriver(batched=False) + + for _ in range(4): + driver._safe_try_publish_event(MockEvent().to_dict(), flush=False) + assert len(driver.batch) == 0 + + def test__safe_try_publish_event_batch(self): + driver = MockEventListenerDriver(batched=True) + + for _ in range(0, 3): + driver._safe_try_publish_event(MockEvent().to_dict(), flush=False) + assert len(driver.batch) == 3 + + def test__safe_try_publish_event_batch_flush(self): + driver = MockEventListenerDriver(batched=True) + + for _ in range(0, 3): + driver._safe_try_publish_event(MockEvent().to_dict(), flush=True) + assert len(driver.batch) == 0 diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py index 51f29ff71..d27f09ec8 100644 --- a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py @@ -39,6 +39,17 @@ def test_try_publish_event_payload(self, mock_post, driver): headers={"Authorization": "Bearer foo bar"}, ) + def try_publish_event_payload_batch(self, mock_post, driver): + for _ in range(3): + event = MockEvent() + driver.try_publish_event_payload(event.to_dict()) + + mock_post.assert_called_with( + url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events", + json=event.to_dict(), + headers={"Authorization": "Bearer foo bar"}, + ) + def test_no_structure_run_id(self): with pytest.raises(ValueError): GriptapeCloudEventListenerDriver(api_key="foo bar") diff --git a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py index f3f872c0a..50021cbe3 100644 --- a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py @@ -23,3 +23,14 @@ def test_try_publish_event_payload(self, mock_post): mock_post.assert_called_once_with( url="foo bar", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"} ) + + def test_try_publish_event_payload_batch(self, mock_post): + driver = WebhookEventListenerDriver(webhook_url="foo bar", headers={"Authorization": "Bearer foo bar"}) + + for _ in range(3): + event = MockEvent() + driver.try_publish_event_payload(event.to_dict()) + + mock_post.assert_called_with( + url="foo bar", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"} + ) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index fcc9688ed..2f32837e0 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -114,7 +114,7 @@ def event_handler(_: BaseEvent): event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent]) event_listener.publish_event(mock_event) - mock_event_listener_driver.publish_event.assert_called_once_with(mock_event) + mock_event_listener_driver.publish_event.assert_called_once_with(mock_event, flush=False) def test_publish_transformed_event(self): mock_event_listener_driver = Mock() @@ -127,4 +127,4 @@ def event_handler(event: BaseEvent): event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent]) event_listener.publish_event(mock_event) - mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}) + mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False)