Skip to content

Commit

Permalink
Use context manager instead of __del__ for futures executor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jan 6, 2025
1 parent 6cd7b67 commit c09fc8b
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 72 deletions.
18 changes: 10 additions & 8 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,19 @@ def batch(self) -> list[dict]:
def publish_event(self, event: BaseEvent | dict) -> None:
event_payload = event if isinstance(event, dict) else event.to_dict()

if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
self._batch = []
else:
self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)
with self.create_futures_executor() as futures_executor:
if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
self._batch = []
else:
futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)

def flush_events(self) -> None:
if self.batch:
self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
with self.create_futures_executor() as futures_executor:
futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
self._batch = []

@abstractmethod
Expand Down
45 changes: 23 additions & 22 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,31 @@ def upsert_text_artifacts(
meta: Optional[dict] = None,
**kwargs,
) -> list[str] | dict[str, list[str]]:
if isinstance(artifacts, list):
return utils.execute_futures_list(
[
self.futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
)
for a in artifacts
],
)
else:
futures_dict = {}

for namespace, artifact_list in artifacts.items():
for a in artifact_list:
if not futures_dict.get(namespace):
futures_dict[namespace] = []

futures_dict[namespace].append(
self.futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
with self.create_futures_executor() as futures_executor:
if isinstance(artifacts, list):
return utils.execute_futures_list(
[
futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
)
for a in artifacts
],
)
else:
futures_dict = {}

for namespace, artifact_list in artifacts.items():
for a in artifact_list:
if not futures_dict.get(namespace):
futures_dict[namespace] = []

futures_dict[namespace].append(
futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
)
)
)

return utils.execute_futures_list_dict(futures_dict)
return utils.execute_futures_list_dict(futures_dict)

def upsert_text_artifact(
self,
Expand Down
7 changes: 4 additions & 3 deletions griptape/engines/rag/stages/response_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ def modules(self) -> list[BaseRagModule]:
def run(self, context: RagContext) -> RagContext:
logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules))

results = utils.execute_futures_list(
[self.futures_executor.submit(with_contextvars(r.run), context) for r in self.response_modules]
)
with self.create_futures_executor() as futures_executor:
results = utils.execute_futures_list(
[futures_executor.submit(with_contextvars(r.run), context) for r in self.response_modules]
)

context.outputs = results

Expand Down
7 changes: 4 additions & 3 deletions griptape/engines/rag/stages/retrieval_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ def modules(self) -> list[BaseRagModule]:
def run(self, context: RagContext) -> RagContext:
logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules))

results = utils.execute_futures_list(
[self.futures_executor.submit(with_contextvars(r.run), context) for r in self.retrieval_modules]
)
with self.create_futures_executor() as futures_executor:
results = utils.execute_futures_list(
[futures_executor.submit(with_contextvars(r.run), context) for r in self.retrieval_modules]
)

# flatten the list of lists
results = list(itertools.chain.from_iterable(results))
Expand Down
13 changes: 7 additions & 6 deletions griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ def load_collection(
# to avoid duplicate work.
sources_by_key = {self.to_key(source): source for source in sources}

return execute_futures_dict(
{
key: self.futures_executor.submit(with_contextvars(self.load), source)
for key, source in sources_by_key.items()
},
)
with self.create_futures_executor() as futures_executor:
return execute_futures_dict(
{
key: futures_executor.submit(with_contextvars(self.load), source)
for key, source in sources_by_key.items()
},
)

def to_key(self, source: S) -> str:
"""Converts the source to a key for the collection."""
Expand Down
11 changes: 0 additions & 11 deletions griptape/mixins/futures_executor_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import contextlib
from abc import ABC
from concurrent import futures
from typing import Callable
Expand All @@ -17,13 +16,3 @@ class FuturesExecutorMixin(ABC):
futures_executor: futures.Executor = field(
default=Factory(lambda self: self.create_futures_executor(), takes_self=True)
)

def __del__(self) -> None:
# don't raise exceptions in __del__
with contextlib.suppress(Exception):
executor = self.futures_executor

if executor is not None:
self.futures_executor = None # pyright: ignore[reportAttributeAccessIssue] In practice this is safe, nobody will access this attribute after this point

executor.shutdown(wait=True)
27 changes: 14 additions & 13 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,24 @@ def insert_task(
def try_run(self, *args) -> Workflow:
exit_loop = False

while not self.is_finished() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()
with self.create_futures_executor() as futures_executor:
while not self.is_finished() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()

for task in ordered_tasks:
if task.can_run():
future = self.futures_executor.submit(with_contextvars(task.run))
futures_list[future] = task
for task in ordered_tasks:
if task.can_run():
future = futures_executor.submit(with_contextvars(task.run))
futures_list[future] = task

# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True
# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True

break
break

return self
return self

def context(self, task: BaseTask) -> dict[str, Any]:
context = super().context(task)
Expand Down
7 changes: 4 additions & 3 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ def try_run(self) -> BaseArtifact:
return ErrorArtifact("no tool output")

def run_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
return utils.execute_futures_list(
[self.futures_executor.submit(with_contextvars(self.run_action), a) for a in actions]
)
with self.create_futures_executor() as futures_executor:
return utils.execute_futures_list(
[futures_executor.submit(with_contextvars(self.run_action), a) for a in actions]
)

def run_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
if action.tool is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TestBaseEventListenerDriver:
def test_publish_event_no_batched(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=False, futures_executor=executor)
driver = MockEventListenerDriver(batched=False, create_futures_executor=lambda: executor)
mock_event_payload = MockEvent().to_dict()

driver.publish_event(mock_event_payload)
Expand All @@ -18,7 +18,7 @@ def test_publish_event_no_batched(self):
def test_publish_event_yes_batched(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=True, futures_executor=executor)
driver = MockEventListenerDriver(batched=True, create_futures_executor=lambda: executor)
mock_event_payload = MockEvent().to_dict()

# Publish 9 events to fill the batch
Expand All @@ -38,7 +38,7 @@ def test_publish_event_yes_batched(self):
def test_flush_events(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=True, futures_executor=executor)
driver = MockEventListenerDriver(batched=True, create_futures_executor=lambda: executor)
driver.try_publish_event_payload_batch = MagicMock(side_effect=driver.try_publish_event_payload)

driver.flush_events()
Expand Down

0 comments on commit c09fc8b

Please sign in to comment.