diff --git a/CHANGELOG.md b/CHANGELOG.md index 525817db9..2b2dafbc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +## [0.30.2] - 2024-08-26 + +### Fixed +- Ensure thread safety when publishing events by adding a thread lock to batch operations in `BaseEventListenerDriver`. + ## [0.30.1] - 2024-08-21 ### Fixed diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py index 0af57f0f3..75bdc9f75 100644 --- a/griptape/drivers/event_listener/base_event_listener_driver.py +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import threading from abc import ABC, abstractmethod from typing import TYPE_CHECKING @@ -18,6 +19,7 @@ class BaseEventListenerDriver(FuturesExecutorMixin, ABC): batched: bool = field(default=True, kw_only=True) batch_size: int = field(default=10, kw_only=True) + thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock())) _batch: list[dict] = field(default=Factory(list), kw_only=True) @@ -39,10 +41,11 @@ def _safe_try_publish_event(self, event: BaseEvent | dict, *, flush: bool) -> No event_payload = event if isinstance(event, dict) else event.to_dict() if self.batched: - self._batch.append(event_payload) - if len(self.batch) >= self.batch_size or flush: - self.try_publish_event_payload_batch(self.batch) - self._batch = [] + with self.thread_lock: + self._batch.append(event_payload) + if len(self.batch) >= self.batch_size or flush: + self.try_publish_event_payload_batch(self.batch) + self._batch = [] return else: self.try_publish_event_payload(event_payload)