From 342d8a3d9c5921fda2630f6a076008ad9391a81b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 22 Oct 2024 17:41:20 -0700 Subject: [PATCH] Add with_context decorator --- CHANGELOG.md | 3 ++ .../base_event_listener_driver.py | 7 +-- .../vector/base_vector_store_driver.py | 7 ++- .../engines/rag/stages/response_rag_stage.py | 3 +- .../engines/rag/stages/retrieval_rag_stage.py | 3 +- griptape/events/event_bus.py | 23 +++++--- griptape/events/event_listener.py | 4 -- griptape/loaders/base_loader.py | 6 ++- griptape/structures/workflow.py | 3 +- griptape/tasks/actions_subtask.py | 5 +- griptape/utils/decorators.py | 12 +++++ tests/unit/events/test_event_listener.py | 21 ++++++++ tests/unit/utils/test_decorators.py | 54 +++++++++++++++++++ 13 files changed, 129 insertions(+), 22 deletions(-) create mode 100644 tests/unit/utils/test_decorators.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b37bf7b180..7c7fbc70df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files. - Events `BaseChunkEvent`, `TextChunkEvent`, `ActionChunkEvent`. - `wrapt` dependency for more robust decorators. +- `griptape.utils.decorators.copy_contextvars` decorator for running functions with the current `contextvars` context. ### Changed @@ -51,6 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Models in `ToolkitTask` with native tool calling no longer need to provide their final answer as `Answer:`. - `EventListener.event_types` will now listen on child types of any provided type. - Only install Tool dependencies if the Tool provides a `requirements.txt` and the dependencies are not already met. +- `EventBus`'s Event Listeners are now thread/coroutine-local. Event Listeners from the spawning thread will be automatically copied when using concurrent griptape features like Workflows. ### Fixed @@ -59,6 +61,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Models occasionally hallucinating `memory_name` and `artifact_namespace` into Tool schemas when using `ToolkitTask`. - Models occasionally providing overly succinct final answers when using `ToolkitTask`. - Exception getting raised in `FuturesExecutorMixin.__del__`. +- Issues when using `EventListener` as a context manager in a multi-threaded environment. ## \[0.33.1\] - 2024-10-11 diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py index f0b0cc7809..b42d5ed696 100644 --- a/griptape/drivers/event_listener/base_event_listener_driver.py +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -8,6 +8,7 @@ from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin +from griptape.utils.decorators import copy_contextvars if TYPE_CHECKING: from griptape.events import BaseEvent @@ -32,14 +33,14 @@ def publish_event(self, event: BaseEvent | dict) -> None: if self.batched: self._batch.append(event_payload) if len(self.batch) >= self.batch_size: - self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch) + self.futures_executor.submit(copy_contextvars(self._safe_publish_event_payload_batch), self.batch) self._batch = [] else: - self.futures_executor.submit(self._safe_publish_event_payload, event_payload) + self.futures_executor.submit(copy_contextvars(self._safe_publish_event_payload), event_payload) def flush_events(self) -> None: if self.batch: - self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch) + self.futures_executor.submit(copy_contextvars(self._safe_publish_event_payload_batch), self.batch) self._batch = [] @abstractmethod diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index e2a394bf49..e0c2eac601 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -11,6 +11,7 @@ from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin from griptape.mixins.serializable_mixin import SerializableMixin +from griptape.utils.decorators import copy_contextvars if TYPE_CHECKING: from griptape.drivers import BaseEmbeddingDriver @@ -47,7 +48,9 @@ def upsert_text_artifacts( if isinstance(artifacts, list): return utils.execute_futures_list( [ - self.futures_executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs) + self.futures_executor.submit( + copy_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs + ) for a in artifacts ], ) @@ -61,7 +64,7 @@ def upsert_text_artifacts( futures_dict[namespace].append( self.futures_executor.submit( - self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs + copy_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs ) ) diff --git a/griptape/engines/rag/stages/response_rag_stage.py b/griptape/engines/rag/stages/response_rag_stage.py index de286317ca..68f6c969ad 100644 --- a/griptape/engines/rag/stages/response_rag_stage.py +++ b/griptape/engines/rag/stages/response_rag_stage.py @@ -7,6 +7,7 @@ from griptape import utils from griptape.engines.rag.stages import BaseRagStage +from griptape.utils.decorators import copy_contextvars if TYPE_CHECKING: from griptape.engines.rag import RagContext @@ -32,7 +33,7 @@ def run(self, context: RagContext) -> RagContext: logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules)) results = utils.execute_futures_list( - [self.futures_executor.submit(r.run, context) for r in self.response_modules] + [self.futures_executor.submit(copy_contextvars(r.run), context) for r in self.response_modules] ) context.outputs = results diff --git a/griptape/engines/rag/stages/retrieval_rag_stage.py b/griptape/engines/rag/stages/retrieval_rag_stage.py index 6ce9fb19fb..0c42f13d4d 100644 --- a/griptape/engines/rag/stages/retrieval_rag_stage.py +++ b/griptape/engines/rag/stages/retrieval_rag_stage.py @@ -9,6 +9,7 @@ from griptape import utils from griptape.artifacts import TextArtifact from griptape.engines.rag.stages import BaseRagStage +from griptape.utils.decorators import copy_contextvars if TYPE_CHECKING: from griptape.engines.rag import RagContext @@ -36,7 +37,7 @@ def run(self, context: RagContext) -> RagContext: logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules)) results = utils.execute_futures_list( - [self.futures_executor.submit(r.run, context) for r in self.retrieval_modules] + [self.futures_executor.submit(copy_contextvars(r.run), context) for r in self.retrieval_modules] ) # flatten the list of lists diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index b7954480e1..3153ec88f8 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import threading from typing import TYPE_CHECKING @@ -11,14 +12,20 @@ from griptape.events import BaseEvent, EventListener +_event_listeners: contextvars.ContextVar[list[EventListener]] = contextvars.ContextVar("event_listeners", default=[]) + + @define class _EventBus(SingletonMixin): - _event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners") _thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()), alias="_thread_lock") @property def event_listeners(self) -> list[EventListener]: - return self._event_listeners + return _event_listeners.get() + + @event_listeners.setter + def event_listeners(self, event_listeners: list[EventListener]) -> None: + _event_listeners.set(event_listeners) def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: return [self.add_event_listener(event_listener) for event_listener in event_listeners] @@ -29,23 +36,23 @@ def remove_event_listeners(self, event_listeners: list[EventListener]) -> None: def add_event_listener(self, event_listener: EventListener) -> EventListener: with self._thread_lock: - if event_listener not in self._event_listeners: - self._event_listeners.append(event_listener) + if event_listener not in self.event_listeners: + self.event_listeners = self.event_listeners + [event_listener] return event_listener def remove_event_listener(self, event_listener: EventListener) -> None: with self._thread_lock: - if event_listener in self._event_listeners: - self._event_listeners.remove(event_listener) + if event_listener in self.event_listeners: + self.event_listeners = [listener for listener in self.event_listeners if listener != event_listener] def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: - for event_listener in self._event_listeners: + for event_listener in self.event_listeners: event_listener.publish_event(event, flush=flush) def clear_event_listeners(self) -> None: with self._thread_lock: - self._event_listeners.clear() + self.event_listeners.clear() EventBus = _EventBus() diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index df4a2668a0..bbca5f83b1 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -30,8 +30,6 @@ class EventListener(Generic[T]): event_types: Optional[list[type[T]]] = field(default=None, kw_only=True) event_listener_driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) - _last_event_listeners: Optional[list[EventListener]] = field(default=None) - def __enter__(self) -> EventListener: from griptape.events import EventBus @@ -44,8 +42,6 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002 EventBus.remove_event_listener(self) - self._last_event_listeners = None - def publish_event(self, event: T, *, flush: bool = False) -> None: event_types = self.event_types diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index f7340283b4..5f4fcee163 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -7,6 +7,7 @@ from griptape.artifacts import BaseArtifact from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin +from griptape.utils.decorators import copy_contextvars from griptape.utils.futures import execute_futures_dict from griptape.utils.hash import bytes_to_hash, str_to_hash @@ -61,7 +62,10 @@ def load_collection( sources_by_key = {self.to_key(source): source for source in sources} return execute_futures_dict( - {key: self.futures_executor.submit(self.load, source) for key, source in sources_by_key.items()}, + { + key: self.futures_executor.submit(copy_contextvars(self.load), source) + for key, source in sources_by_key.items() + }, ) def to_key(self, source: S) -> str: diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index 99af20dc29..280475061c 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -10,6 +10,7 @@ from griptape.common import observable from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin from griptape.structures import Structure +from griptape.utils.decorators import copy_contextvars if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -108,7 +109,7 @@ def try_run(self, *args) -> Workflow: for task in ordered_tasks: if task.can_execute(): - future = self.futures_executor.submit(task.execute) + future = self.futures_executor.submit(copy_contextvars(task.execute)) futures_list[future] = task # Wait for all tasks to complete diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 1b4ccfb7da..1cc545db12 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -16,6 +16,7 @@ from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively +from griptape.utils.decorators import copy_contextvars if TYPE_CHECKING: from griptape.memory import TaskMemory @@ -139,7 +140,9 @@ def run(self) -> BaseArtifact: return ErrorArtifact("no tool output") def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]: - return utils.execute_futures_list([self.futures_executor.submit(self.execute_action, a) for a in actions]) + return utils.execute_futures_list( + [self.futures_executor.submit(copy_contextvars(self.execute_action), a) for a in actions] + ) def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: if action.tool is not None: diff --git a/griptape/utils/decorators.py b/griptape/utils/decorators.py index 3eef6d8d0c..fe3d3a3f0b 100644 --- a/griptape/utils/decorators.py +++ b/griptape/utils/decorators.py @@ -1,10 +1,12 @@ from __future__ import annotations +import contextvars import functools import inspect from typing import Any, Callable, Optional import schema +import wrapt from schema import Schema CONFIG_SCHEMA = Schema( @@ -15,6 +17,16 @@ ) +def copy_contextvars(wrapped: Callable) -> Callable: + ctx = contextvars.copy_context() + + @wrapt.decorator + def wrapper(wrapped: Callable, instance: Any, args: tuple, kwargs: dict) -> Any: + return ctx.run(wrapped, *args, **kwargs) + + return wrapper(wrapped) # pyright: ignore[reportCallIssue] + + def activity(config: dict) -> Any: validated_config = CONFIG_SCHEMA.validate(config) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index b3aee2891c..d2bec0b00c 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -20,6 +20,7 @@ from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline from griptape.tasks import ActionsSubtask, ToolkitTask +from griptape.utils.decorators import copy_contextvars from tests.mocks.mock_event import MockEvent from tests.mocks.mock_event_listener_driver import MockEventListenerDriver from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -185,6 +186,26 @@ def test_context_manager_multiple(self): assert EventBus.event_listeners == [e1] + def test_threaded(self): + from concurrent import futures + + thread_pool_executor = futures.ThreadPoolExecutor() + + e1 = EventListener(lambda e: e) + EventBus.add_event_listener(e1) + + def handler() -> None: + e2 = EventListener(lambda e: e) + EventBus.add_event_listener(e2) + assert EventBus.event_listeners == [e1, e2] + EventBus.remove_event_listener(e2) + assert EventBus.event_listeners == [e1] + EventBus.add_event_listener(e2) + + thread_pool_executor.submit(copy_contextvars(handler)).result() + + assert EventBus.event_listeners == [e1] + def test_publish_event_yes_flush(self): mock_event_listener_driver = MockEventListenerDriver() mock_event_listener_driver.flush_events = Mock(side_effect=mock_event_listener_driver.flush_events) diff --git a/tests/unit/utils/test_decorators.py b/tests/unit/utils/test_decorators.py new file mode 100644 index 0000000000..0e7d21f29c --- /dev/null +++ b/tests/unit/utils/test_decorators.py @@ -0,0 +1,54 @@ +import contextvars +import threading + +from griptape.utils.decorators import copy_contextvars + + +class TestDecorators: + def test_copy_contextvars_decorator(self): + context_var = contextvars.ContextVar("context_var") + context_var.set("test") + + def undecorated_function(vals: list) -> None: + vals.append(context_var.get()) + + @copy_contextvars + def decorated_function(vals: list) -> None: + vals.append(context_var.get()) + + return_values = [] + thread = threading.Thread(target=decorated_function, args=(return_values,)) + thread.start() + thread.join() + + assert return_values == ["test"] + + return_values = [] + thread = threading.Thread(target=undecorated_function, args=(return_values,)) + thread.start() + thread.join() + + assert return_values == [] + + def test_copy_contextvars_direct(self): + context_var = contextvars.ContextVar("context_var") + context_var.set("test") + + def function(vals: list) -> None: + vals.append(context_var.get()) + + decoratored_function = copy_contextvars(function) + + return_values = [] + thread = threading.Thread(target=decoratored_function, args=(return_values,)) + thread.start() + thread.join() + + assert return_values == ["test"] + + return_values = [] + thread = threading.Thread(target=function, args=(return_values,)) + thread.start() + thread.join() + + assert return_values == []