Simplify Parallel Executor #2031

Open · wants to merge 8 commits into base: fix-streaming
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml

@@ -8,7 +8,7 @@ on:
     types: [opened, synchronize, reopened]

 env:
-  POETRY_VERSION: "1.7.1"
+  POETRY_VERSION: "2.0.0"

 jobs:
   fix:
217 changes: 70 additions & 147 deletions dspy/utils/parallelizer.py
@@ -1,12 +1,11 @@
-import sys
-import tqdm
-import signal
-from concurrent.futures import ThreadPoolExecutor
 import logging
+import sys
 import threading
 import traceback
-import contextlib

-from tqdm.contrib.logging import logging_redirect_tqdm
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm

 logger = logging.getLogger(__name__)

@@ -20,6 +19,8 @@ def __init__(
         provide_traceback=False,
         compare_results=False,
     ):
+        assert num_threads > 0
+
         """Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""
         self.num_threads = num_threads
         self.disable_progress_bar = disable_progress_bar
@@ -28,180 +29,102 @@ def __init__(
         self.compare_results = compare_results

         self.error_count = 0
-        self.error_lock = threading.Lock()
-        self.cancel_jobs = threading.Event()
+        self._lock = threading.Lock()
     def execute(self, function, data):
         wrapped_function = self._wrap_function(function)
-        if self.num_threads == 1:
-            return self._execute_isolated_single_thread(wrapped_function, data)
-        else:
-            return self._execute_multi_thread(wrapped_function, data)
+        exec_type = "multi" if self.num_threads != 1 else "single"
+        executor = getattr(self, f"_execute_{exec_type}_thread")
+        return executor(wrapped_function, data)

[Review comment (Collaborator)] Could we do `self._execute_single_thread if self.num_threads == 1 else self._execute_multi_thread`?
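A minimal sketch of the reviewer's suggestion, binding the handler directly instead of building its name for getattr (illustrative only, assuming the two method names in this diff):

    def execute(self, function, data):
        wrapped_function = self._wrap_function(function)
        # A plain conditional expression over bound methods; no string formatting needed.
        executor = self._execute_single_thread if self.num_threads == 1 else self._execute_multi_thread
        return executor(wrapped_function, data)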

     def _wrap_function(self, function):
         from dspy.dsp.utils.settings import thread_local_overrides

-        # Wrap the function with error handling
         def wrapped(item):
-            if self.cancel_jobs.is_set():
-                return None
             original_overrides = thread_local_overrides.overrides
             thread_local_overrides.overrides = thread_local_overrides.overrides.copy()
             try:
                 return function(item)
             except Exception as e:
-                with self.error_lock:
-                    self.error_count += 1
-                    current_error_count = self.error_count
-                if current_error_count >= self.max_errors:
-                    self.cancel_jobs.set()
-                    raise e
                 if self.provide_traceback:
-                    logger.error(
-                        f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}"
-                    )
+                    logger.error(f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}")
                 else:
                     logger.error(
                         f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
                     )
+                with self._lock:
+                    self.error_count += 1
+                    if self.error_count >= self.max_errors:
+                        raise e
                 return None
             finally:
                 thread_local_overrides.overrides = original_overrides

         return wrapped

[Review comment (Collaborator), on the added `with self._lock:` block] I would probably move this above the logging.
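A sketch of the reordering that comment suggests: increment under the lock first, then log (illustrative only; note that once max_errors is reached the exception re-raises before the log line for that item is emitted):

            except Exception as e:
                with self._lock:
                    self.error_count += 1
                    # If the cap is hit, re-raise immediately; this error is not logged.
                    if self.error_count >= self.max_errors:
                        raise e
                if self.provide_traceback:
                    logger.error(f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}")
                else:
                    logger.error(f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace.")
                return None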

-    def _execute_isolated_single_thread(self, function, data):
-        results = []
-        pbar = tqdm.tqdm(
-            total=len(data),
-            dynamic_ncols=True,
-            disable=self.disable_progress_bar,
-            file=sys.stdout
-        )
-
-        from dspy.dsp.utils.settings import thread_local_overrides
-        original_overrides = thread_local_overrides.overrides
-
-        for item in data:
-            with logging_redirect_tqdm():
-                if self.cancel_jobs.is_set():
-                    break
-
-                # Create an isolated context for each task by copying current overrides
-                # This way, even if an iteration modifies the overrides, it won't affect subsequent iterations
-                thread_local_overrides.overrides = original_overrides.copy()
-
-                try:
-                    result = function(item)
-                    results.append(result)
-                finally:
-                    thread_local_overrides.overrides = original_overrides
-
-                if self.compare_results:
-                    # Assumes score is the last element of the result tuple
-                    self._update_progress(
-                        pbar,
-                        sum([r[-1] for r in results if r is not None]),
-                        len([r for r in data if r is not None]),
-                    )
-                else:
-                    self._update_progress(pbar, len(results), len(data))
-
-        pbar.close()
-
-        if self.cancel_jobs.is_set():
-            logger.warning("Execution was cancelled due to errors.")
-            raise Exception("Execution was cancelled due to errors.")
-
-        return results
-
-    def _update_progress(self, pbar, nresults, ntotal):
-        if self.compare_results:
-            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
-            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
-        else:
-            pbar.set_description(f"Processed {nresults} / {ntotal} examples")
-
-        pbar.update()
+    def _create_pbar(self, data: list):
+        return tqdm(total=len(data), dynamic_ncols=True, disable=self.disable_progress_bar, file=sys.stdout)
+
+    def _update_pbar(self, pbar: tqdm, nresults, ntotal):
+        if self.compare_results:
+            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
+            description = f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)"
+        else:
+            description = f"Processed {nresults} / {ntotal} examples"
+        pbar.set_description(description, refresh=True)
+
+    def _execute_single_thread(self, function, data):
+        pbar = self._create_pbar(data)
+        total_score = 0
+        total_processed = 0
+
+        def function_with_progress(item):
+            result = function(item)
+
+            with self._lock:
+                nonlocal total_score, total_processed, pbar
+                total_processed += 1
+                if self.compare_results:
+                    # Assumes score is the last element of the result tuple
+                    if result is not None:
+                        total_score += result[-1]
+                    self._update_pbar(pbar, total_score, total_processed)
+                else:
+                    self._update_pbar(pbar, total_processed, len(data))

+            return result
+
+        try:
+            return list(map(function_with_progress, data))
+        finally:
+            pbar.close()

-    def _execute_multi_thread(self, function, data):
-        results = [None] * len(data)  # Pre-allocate results list to maintain order
-        job_cancelled = "cancelled"
-
-        @contextlib.contextmanager
-        def interrupt_handler_manager():
-            """Sets the cancel_jobs event when a SIGINT is received, only in the main thread."""
-
-            # TODO: Is this check conducive to nested usage of ParallelExecutor?
-            if threading.current_thread() is threading.main_thread():
-                default_handler = signal.getsignal(signal.SIGINT)
-
-                def interrupt_handler(sig, frame):
-                    self.cancel_jobs.set()
-                    logger.warning("Received SIGINT. Cancelling execution.")
-                    # Re-raise the signal to allow default behavior
-                    default_handler(sig, frame)
-
-                signal.signal(signal.SIGINT, interrupt_handler)
-                try:
-                    yield
-                finally:
-                    signal.signal(signal.SIGINT, default_handler)
-            else:
-                # If not in the main thread, skip setting signal handlers
-                yield
-
-        def cancellable_function(parent_overrides, index_item):
-            index, item = index_item
-            if self.cancel_jobs.is_set():
-                return index, job_cancelled
-
-            # Create an isolated context for each task by copying parent's overrides
-            from dspy.dsp.utils.settings import thread_local_overrides
-            original_overrides = thread_local_overrides.overrides
-            thread_local_overrides.overrides = parent_overrides.copy()
-
-            with logging_redirect_tqdm():
-                try:
-                    return index, function(item)
-                finally:
-                    thread_local_overrides.overrides = original_overrides
-
-        with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
-            from dspy.dsp.utils.settings import thread_local_overrides
-            parent_overrides = thread_local_overrides.overrides.copy()
-
-            futures = {}
-            for pair in enumerate(data):
-                # Pass the parent thread's overrides to each thread
-                future = executor.submit(cancellable_function, parent_overrides, pair)
-                futures[future] = pair
-
-            pbar = tqdm.tqdm(
-                total=len(data),
-                dynamic_ncols=True,
-                disable=self.disable_progress_bar,
-                file=sys.stdout
-            )
-
-            for future in as_completed(futures):
-                index, result = future.result()
-
-                if result is job_cancelled:
-                    continue
-
-                results[index] = result
-                if self.compare_results:
-                    # Assumes score is the last element of the result tuple
-                    self._update_progress(
-                        pbar,
-                        sum([r[-1] for r in results if r is not None]),
-                        len([r for r in results if r is not None]),
-                    )
-                else:
-                    self._update_progress(
-                        pbar,
-                        len([r for r in results if r is not None]),
-                        len(data),
-                    )
-
-            pbar.close()
-
-        if self.cancel_jobs.is_set():
-            logger.warning("Execution was cancelled due to errors.")
-            raise Exception("Execution was cancelled due to errors.")
-
-        return results

[Review comment (Collaborator), on the removed `cancellable_function`] If I recall correctly, this was really important. It seems to have gotten lost in the (otherwise extremely neat) refactor. When launching multiple threads, we want each thread to inherit the parent thread's local overrides.
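For context, a hypothetical call site for the simplified executor (constructor arguments mirror the fields this diff references; not code from the PR):

    executor = ParallelExecutor(num_threads=4, max_errors=10, provide_traceback=True)
    results = executor.execute(lambda item: item * 2, [1, 2, 3, 4])
    # results == [2, 4, 6, 8]; both map() and ThreadPoolExecutor.map() return
    # results in input order, which is why the refactor can drop the old
    # pre-allocated results list and its index bookkeeping.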
+    def _execute_multi_thread(self, function, data):
+        pbar = self._create_pbar(data)
+        total_score = 0
+        total_processed = 0
+
+        def function_with_progress(item):
+            result = function(item)
+
+            with self._lock:
+                nonlocal total_score, total_processed, pbar
+                total_processed += 1
+                if self.compare_results:
+                    # Assumes score is the last element of the result tuple
+                    if result is not None:
+                        total_score += result[-1]
+                    self._update_pbar(pbar, total_score, total_processed)
+                else:
+                    self._update_pbar(pbar, total_processed, len(data))
+
+            return result
+
+        with ThreadPoolExecutor(max_workers=self.num_threads) as pool:
+            try:
+                return list(pool.map(function_with_progress, data))
+            except Exception:
+                pool.shutdown(wait=False, cancel_futures=True)
+                raise
+            finally:
+                pbar.close()

[Review comment (Collaborator), near the top of `_execute_multi_thread`] Somewhere here, we'd need to have something like:

    from dspy.dsp.utils.settings import thread_local_overrides
    parent_overrides = thread_local_overrides.overrides.copy()

and then we should pass parent_overrides in data, so that wrapped(item) can handle using the parent thread's overrides, not the new child's overrides.
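A sketch of what the reviewers are asking for: snapshot the parent thread's overrides once in the calling thread, then install a copy inside each worker before running the task (illustrative only; progress-bar bookkeeping omitted, and `thread_local_overrides` is the dspy settings object referenced in this diff):

    def _execute_multi_thread(self, function, data):
        from dspy.dsp.utils.settings import thread_local_overrides

        # Captured in the parent thread, before any workers start.
        parent_overrides = thread_local_overrides.overrides.copy()

        def function_with_overrides(item):
            # Each worker runs with a copy of the parent's overrides rather
            # than whatever its own thread last had installed.
            original = thread_local_overrides.overrides
            thread_local_overrides.overrides = parent_overrides.copy()
            try:
                return function(item)
            finally:
                thread_local_overrides.overrides = original

        with ThreadPoolExecutor(max_workers=self.num_threads) as pool:
            return list(pool.map(function_with_overrides, data))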