Skip to content

Commit

Permalink
Add with_context decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 25, 2024
1 parent 3542d06 commit 6f38c07
Show file tree
Hide file tree
Showing 13 changed files with 131 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files.
- Events `BaseChunkEvent`, `TextChunkEvent`, `ActionChunkEvent`.
- `wrapt` dependency for more robust decorators.
- `griptape.utils.decorators.copy_contextvars` decorator for running functions with the current `contextvars` context.

### Changed

Expand Down Expand Up @@ -51,6 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Models in `ToolkitTask` with native tool calling no longer need to provide their final answer as `Answer:`.
- `EventListener.event_types` will now listen on child types of any provided type.
- Only install Tool dependencies if the Tool provides a `requirements.txt` and the dependencies are not already met.
- `EventBus`'s Event Listeners are now thread/coroutine-local. Event Listeners from the spawning thread will be automatically copied when using concurrent griptape features like Workflows.

### Fixed

Expand All @@ -59,6 +61,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Models occasionally hallucinating `memory_name` and `artifact_namespace` into Tool schemas when using `ToolkitTask`.
- Models occasionally providing overly succinct final answers when using `ToolkitTask`.
- Exception getting raised in `FuturesExecutorMixin.__del__`.
- Issues when using `EventListener` as a context manager in a multi-threaded environment.

## \[0.33.1\] - 2024-10-11

Expand Down
7 changes: 4 additions & 3 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin
from griptape.utils.decorators import copy_contextvars

if TYPE_CHECKING:
from griptape.events import BaseEvent
Expand All @@ -32,14 +33,14 @@ def publish_event(self, event: BaseEvent | dict) -> None:
if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch)
self.futures_executor.submit(copy_contextvars(self._safe_publish_event_payload_batch), self.batch)
self._batch = []
else:
self.futures_executor.submit(self._safe_publish_event_payload, event_payload)
self.futures_executor.submit(copy_contextvars(self._safe_publish_event_payload), event_payload)

def flush_events(self) -> None:
if self.batch:
self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch)
self.futures_executor.submit(copy_contextvars(self._safe_publish_event_payload_batch), self.batch)
self._batch = []

@abstractmethod
Expand Down
7 changes: 5 additions & 2 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin
from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.utils.decorators import copy_contextvars

if TYPE_CHECKING:
from griptape.drivers import BaseEmbeddingDriver
Expand Down Expand Up @@ -47,7 +48,9 @@ def upsert_text_artifacts(
if isinstance(artifacts, list):
return utils.execute_futures_list(
[
self.futures_executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs)
self.futures_executor.submit(
copy_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
)
for a in artifacts
],
)
Expand All @@ -61,7 +64,7 @@ def upsert_text_artifacts(

futures_dict[namespace].append(
self.futures_executor.submit(
self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs
copy_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
)
)

Expand Down
3 changes: 2 additions & 1 deletion griptape/engines/rag/stages/response_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape import utils
from griptape.engines.rag.stages import BaseRagStage
from griptape.utils.decorators import copy_contextvars

if TYPE_CHECKING:
from griptape.engines.rag import RagContext
Expand All @@ -32,7 +33,7 @@ def run(self, context: RagContext) -> RagContext:
logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules))

results = utils.execute_futures_list(
[self.futures_executor.submit(r.run, context) for r in self.response_modules]
[self.futures_executor.submit(copy_contextvars(r.run), context) for r in self.response_modules]
)

context.outputs = results
Expand Down
3 changes: 2 additions & 1 deletion griptape/engines/rag/stages/retrieval_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from griptape import utils
from griptape.artifacts import TextArtifact
from griptape.engines.rag.stages import BaseRagStage
from griptape.utils.decorators import copy_contextvars

if TYPE_CHECKING:
from griptape.engines.rag import RagContext
Expand Down Expand Up @@ -36,7 +37,7 @@ def run(self, context: RagContext) -> RagContext:
logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules))

results = utils.execute_futures_list(
[self.futures_executor.submit(r.run, context) for r in self.retrieval_modules]
[self.futures_executor.submit(copy_contextvars(r.run), context) for r in self.retrieval_modules]
)

# flatten the list of lists
Expand Down
23 changes: 15 additions & 8 deletions griptape/events/event_bus.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextvars
import threading
from typing import TYPE_CHECKING

Expand All @@ -11,14 +12,20 @@
from griptape.events import BaseEvent, EventListener


_event_listeners: contextvars.ContextVar[list[EventListener]] = contextvars.ContextVar("event_listeners", default=[])


@define
class _EventBus(SingletonMixin):
_event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners")
_thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()), alias="_thread_lock")

@property
def event_listeners(self) -> list[EventListener]:
return self._event_listeners
return _event_listeners.get()

@event_listeners.setter
def event_listeners(self, event_listeners: list[EventListener]) -> None:
_event_listeners.set(event_listeners)

def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]:
return [self.add_event_listener(event_listener) for event_listener in event_listeners]
Expand All @@ -29,23 +36,23 @@ def remove_event_listeners(self, event_listeners: list[EventListener]) -> None:

def add_event_listener(self, event_listener: EventListener) -> EventListener:
with self._thread_lock:
if event_listener not in self._event_listeners:
self._event_listeners.append(event_listener)
if event_listener not in self.event_listeners:
self.event_listeners = self.event_listeners + [event_listener]

return event_listener

def remove_event_listener(self, event_listener: EventListener) -> None:
with self._thread_lock:
if event_listener in self._event_listeners:
self._event_listeners.remove(event_listener)
if event_listener in self.event_listeners:
self.event_listeners = [listener for listener in self.event_listeners if listener != event_listener]

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
for event_listener in self._event_listeners:
for event_listener in self.event_listeners:
event_listener.publish_event(event, flush=flush)

def clear_event_listeners(self) -> None:
with self._thread_lock:
self._event_listeners.clear()
self.event_listeners = []


EventBus = _EventBus()
4 changes: 0 additions & 4 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ class EventListener(Generic[T]):
event_types: Optional[list[type[T]]] = field(default=None, kw_only=True)
event_listener_driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True)

_last_event_listeners: Optional[list[EventListener]] = field(default=None)

def __enter__(self) -> EventListener:
from griptape.events import EventBus

Expand All @@ -44,8 +42,6 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002

EventBus.remove_event_listener(self)

self._last_event_listeners = None

def publish_event(self, event: T, *, flush: bool = False) -> None:
event_types = self.event_types

Expand Down
6 changes: 5 additions & 1 deletion griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape.artifacts import BaseArtifact
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin
from griptape.utils.decorators import copy_contextvars
from griptape.utils.futures import execute_futures_dict
from griptape.utils.hash import bytes_to_hash, str_to_hash

Expand Down Expand Up @@ -61,7 +62,10 @@ def load_collection(
sources_by_key = {self.to_key(source): source for source in sources}

return execute_futures_dict(
{key: self.futures_executor.submit(self.load, source) for key, source in sources_by_key.items()},
{
key: self.futures_executor.submit(copy_contextvars(self.load), source)
for key, source in sources_by_key.items()
},
)

def to_key(self, source: S) -> str:
Expand Down
3 changes: 2 additions & 1 deletion griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from griptape.common import observable
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin
from griptape.structures import Structure
from griptape.utils.decorators import copy_contextvars

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
Expand Down Expand Up @@ -108,7 +109,7 @@ def try_run(self, *args) -> Workflow:

for task in ordered_tasks:
if task.can_execute():
future = self.futures_executor.submit(task.execute)
future = self.futures_executor.submit(copy_contextvars(task.execute))
futures_list[future] = task

# Wait for all tasks to complete
Expand Down
5 changes: 4 additions & 1 deletion griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin
from griptape.tasks import BaseTask
from griptape.utils import remove_null_values_in_dict_recursively
from griptape.utils.decorators import copy_contextvars

if TYPE_CHECKING:
from griptape.memory import TaskMemory
Expand Down Expand Up @@ -139,7 +140,9 @@ def run(self) -> BaseArtifact:
return ErrorArtifact("no tool output")

def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
return utils.execute_futures_list([self.futures_executor.submit(self.execute_action, a) for a in actions])
return utils.execute_futures_list(
[self.futures_executor.submit(copy_contextvars(self.execute_action), a) for a in actions]
)

def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
if action.tool is not None:
Expand Down
12 changes: 12 additions & 0 deletions griptape/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import contextvars
import functools
import inspect
from typing import Any, Callable, Optional

import schema
import wrapt
from schema import Schema

CONFIG_SCHEMA = Schema(
Expand All @@ -15,6 +17,16 @@
)


def copy_contextvars(wrapped: Callable) -> Callable:
ctx = contextvars.copy_context()

@wrapt.decorator
def wrapper(wrapped: Callable, instance: Any, args: tuple, kwargs: dict) -> Any:
return ctx.run(wrapped, *args, **kwargs)

return wrapper(wrapped) # pyright: ignore[reportCallIssue]


def activity(config: dict) -> Any:
validated_config = CONFIG_SCHEMA.validate(config)

Expand Down
23 changes: 23 additions & 0 deletions tests/unit/events/test_event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from griptape.events.base_event import BaseEvent
from griptape.structures import Pipeline
from griptape.tasks import ActionsSubtask, ToolkitTask
from griptape.utils.decorators import copy_contextvars
from tests.mocks.mock_event import MockEvent
from tests.mocks.mock_event_listener_driver import MockEventListenerDriver
from tests.mocks.mock_prompt_driver import MockPromptDriver
Expand Down Expand Up @@ -185,6 +186,28 @@ def test_context_manager_multiple(self):

assert EventBus.event_listeners == [e1]

def test_threaded(self):
from concurrent import futures

thread_pool_executor = futures.ThreadPoolExecutor()

e1 = EventListener(lambda e: e)
EventBus.add_event_listener(e1)

def handler() -> None:
e2 = EventListener(lambda e: e)
EventBus.add_event_listener(e2)
assert EventBus.event_listeners == [e1, e2]
EventBus.remove_event_listener(e2)
assert EventBus.event_listeners == [e1]
EventBus.clear_event_listeners()
assert EventBus.event_listeners == []
EventBus.add_event_listener(e2)

thread_pool_executor.submit(copy_contextvars(handler)).result()

assert EventBus.event_listeners == [e1]

def test_publish_event_yes_flush(self):
mock_event_listener_driver = MockEventListenerDriver()
mock_event_listener_driver.flush_events = Mock(side_effect=mock_event_listener_driver.flush_events)
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/utils/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import contextvars
import threading

from griptape.utils.decorators import copy_contextvars


class TestDecorators:
def test_copy_contextvars_decorator(self):
context_var = contextvars.ContextVar("context_var")
context_var.set("test")

def undecorated_function(vals: list) -> None:
vals.append(context_var.get())

@copy_contextvars
def decorated_function(vals: list) -> None:
vals.append(context_var.get())

return_values = []
thread = threading.Thread(target=decorated_function, args=(return_values,))
thread.start()
thread.join()

assert return_values == ["test"]

return_values = []
thread = threading.Thread(target=undecorated_function, args=(return_values,))
thread.start()
thread.join()

assert return_values == []

def test_copy_contextvars_direct(self):
context_var = contextvars.ContextVar("context_var")
context_var.set("test")

def function(vals: list) -> None:
vals.append(context_var.get())

decoratored_function = copy_contextvars(function)

return_values = []
thread = threading.Thread(target=decoratored_function, args=(return_values,))
thread.start()
thread.join()

assert return_values == ["test"]

return_values = []
thread = threading.Thread(target=function, args=(return_values,))
thread.start()
thread.join()

assert return_values == []

0 comments on commit 6f38c07

Please sign in to comment.