Skip to content

Commit

Permalink
Add retry to event listener driver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 10, 2024
1 parent 88fda00 commit a5de0e3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `griptape.configs.logging.JsonFormatter` for formatting logs as JSON.
- Request/response debug logging to all Prompt Drivers.
- `BaseEventListener.flush_events()` to flush events from an Event Listener.
- Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing.

### Changed

Expand Down
38 changes: 24 additions & 14 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from attrs import Factory, define, field

from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin

if TYPE_CHECKING:
Expand All @@ -15,7 +16,7 @@


@define
class BaseEventListenerDriver(FuturesExecutorMixin, ABC):
class BaseEventListenerDriver(FuturesExecutorMixin, ExponentialBackoffMixin, ABC):
batched: bool = field(default=True, kw_only=True)
batch_size: int = field(default=10, kw_only=True)

Expand All @@ -28,26 +29,35 @@ def batch(self) -> list[dict]:
def publish_event(self, event: BaseEvent | dict) -> None:
event_payload = event if isinstance(event, dict) else event.to_dict()

try:
if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
self._flush_events()
else:
self.futures_executor.submit(self.try_publish_event_payload, event_payload)
except Exception as e:
logger.error(e)
if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch)
self._batch = []
else:
self.futures_executor.submit(self._safe_publish_event_payload, event_payload)

def flush_events(self) -> None:
if self.batch:
self._flush_events()
self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch)
self._batch = []

@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None: ...

@abstractmethod
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: ...

def _flush_events(self) -> None:
self.futures_executor.submit(self.try_publish_event_payload_batch, self.batch)
self._batch = []
def _safe_publish_event_payload(self, event_payload: dict) -> None:
for attempt in self.retrying():
with attempt:
self.try_publish_event_payload(event_payload)
else:
logger.error("event listener driver failed after all retry attempts")

def _safe_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
for attempt in self.retrying():
with attempt:
self.try_publish_event_payload_batch(event_payload_batch)
else:
logger.error("event listener driver failed after all retry attempts")
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,38 @@

class TestBaseEventListenerDriver:
def test_publish_event_no_batched(self):
driver = MockEventListenerDriver(batched=False)
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=False, futures_executor=executor)
driver.try_publish_event_payload = MagicMock(side_effect=driver.try_publish_event_payload)
mock_event_payload = MockEvent().to_dict()

driver.publish_event(MockEvent().to_dict())
driver.publish_event(mock_event_payload)

driver.try_publish_event_payload.assert_called_once()
executor.submit.assert_called_once_with(driver._safe_publish_event_payload, mock_event_payload)

def test_publish_event_yes_batched(self):
driver = MockEventListenerDriver(batched=True)
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=True, futures_executor=executor)
driver.try_publish_event_payload_batch = MagicMock(side_effect=driver.try_publish_event_payload)
mock_event_payload = MockEvent().to_dict()

for _ in range(0, 9):
driver.publish_event(MockEvent().to_dict())
mock_event_payloads = [mock_event_payload for _ in range(0, 9)]
for mock_event_payload in mock_event_payloads:
driver.publish_event(mock_event_payload)

assert len(driver._batch) == 9
executor.submit.assert_not_called()
driver.try_publish_event_payload_batch.assert_not_called()

# Publish the 10th event to trigger the batch publish
driver.publish_event(MockEvent().to_dict())
driver.publish_event(mock_event_payload)

assert len(driver._batch) == 0
driver.try_publish_event_payload_batch.assert_called_once()
executor.submit.assert_called_once_with(
driver._safe_publish_event_payload_batch, [*mock_event_payloads, mock_event_payload]
)

def test_flush_events(self):
driver = MockEventListenerDriver(batched=True)
Expand Down

0 comments on commit a5de0e3

Please sign in to comment.