Simplify Parallel Executor #2031

Status: Open

Wants to merge 8 commits into base: fix-streaming

Changes from 3 commits
218 changes: 67 additions & 151 deletions dspy/utils/parallelizer.py
@@ -1,12 +1,11 @@
import sys
import tqdm
import signal
from concurrent.futures import ThreadPoolExecutor
import logging
import sys
import threading
import traceback
import contextlib

from tqdm.contrib.logging import logging_redirect_tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import tqdm

logger = logging.getLogger(__name__)

@@ -20,6 +19,8 @@ def __init__(
provide_traceback=False,
compare_results=False,
):
assert num_threads > 0

"""Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""
self.num_threads = num_threads
self.disable_progress_bar = disable_progress_bar
@@ -28,180 +29,95 @@ def __init__(
self.compare_results = compare_results

self.error_count = 0
self.error_lock = threading.Lock()
self.cancel_jobs = threading.Event()
self._lock = threading.Lock()

def execute(self, function, data):
wrapped_function = self._wrap_function(function)
if self.num_threads == 1:
return self._execute_isolated_single_thread(wrapped_function, data)
else:
return self._execute_multi_thread(wrapped_function, data)
exec_type = "multi" if self.num_threads != 1 else "single"
executor = getattr(self, f"_execute_{exec_type}_thread")
Collaborator comment: Could we use self._execute_single_thread if self.num_threads == 1 else self._execute_multi_thread here?

return executor(wrapped_function, data)
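
A minimal sketch of the reviewer's suggested dispatch, assuming the method names shown in this diff (no getattr or string formatting needed):

    def execute(self, function, data):
        wrapped_function = self._wrap_function(function)
        # Choose the executor method directly rather than building its name as a string.
        executor = self._execute_single_thread if self.num_threads == 1 else self._execute_multi_thread
        return executor(wrapped_function, data)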

def _wrap_function(self, function):
from dspy.dsp.utils.settings import thread_local_overrides

# Wrap the function with error handling
def wrapped(item):
if self.cancel_jobs.is_set():
return None
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = thread_local_overrides.overrides.copy()
try:
return function(item)
except Exception as e:
with self.error_lock:
self.error_count += 1
current_error_count = self.error_count
if current_error_count >= self.max_errors:
self.cancel_jobs.set()
raise e
if self.provide_traceback:
logger.error(
f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}"
)
logger.error(f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}")
else:
logger.error(
f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
)
with self._lock:
self.error_count += 1
Collaborator comment: I would probably move this above the logging.

if self.error_count >= self.max_errors:
raise e
return None
finally:
thread_local_overrides.overrides = original_overrides

return wrapped
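
A sketch of the reordering the reviewer suggests, with the error count updated (and the error budget checked) before any logging; the surrounding structure follows the wrapped function shown in this diff:

    def wrapped(item):
        original_overrides = thread_local_overrides.overrides
        thread_local_overrides.overrides = thread_local_overrides.overrides.copy()
        try:
            return function(item)
        except Exception as e:
            # Count the error and decide whether to abort while holding the lock,
            # before emitting any log output.
            with self._lock:
                self.error_count += 1
                should_raise = self.error_count >= self.max_errors
            if self.provide_traceback:
                logger.error(f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}")
            else:
                logger.error(f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace.")
            if should_raise:
                raise e
            return None
        finally:
            thread_local_overrides.overrides = original_overrides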

def _execute_isolated_single_thread(self, function, data):
results = []
pbar = tqdm.tqdm(
total=len(data),
dynamic_ncols=True,
disable=self.disable_progress_bar,
file=sys.stdout
)
def _create_pbar(self, data: list):
return tqdm.tqdm(total=len(data), dynamic_ncols=True, disable=self.disable_progress_bar, file=sys.stdout)

from dspy.dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides

for item in data:
with logging_redirect_tqdm():
if self.cancel_jobs.is_set():
break

# Create an isolated context for each task by copying current overrides
# This way, even if an iteration modifies the overrides, it won't affect subsequent iterations
thread_local_overrides.overrides = original_overrides.copy()

try:
result = function(item)
results.append(result)
finally:
thread_local_overrides.overrides = original_overrides

if self.compare_results:
# Assumes score is the last element of the result tuple
self._update_progress(
pbar,
sum([r[-1] for r in results if r is not None]),
len([r for r in data if r is not None]),
)
else:
self._update_progress(pbar, len(results), len(data))
def _update_pbar(self, pbar: tqdm.tqdm, nresults, ntotal):
if self.compare_results:
percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)", refresh=True)
else:
pbar.set_description(f"Processed {nresults} / {ntotal} examples", refresh=True)

pbar.close()
def _execute_single_thread(self, function, data):
total_score = 0
total_processed = 0

if self.cancel_jobs.is_set():
logger.warning("Execution was cancelled due to errors.")
raise Exception("Execution was cancelled due to errors.")
def function_with_progress(item):
result = function(item)

return results
nonlocal total_score, total_processed, pbar
total_processed += 1
if self.compare_results:
if result is not None:
total_score += result[-1]
self._update_pbar(pbar, total_score, total_processed)
else:
self._update_pbar(pbar, total_processed, len(data))

def _update_progress(self, pbar, nresults, ntotal):
if self.compare_results:
percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
else:
pbar.set_description(f"Processed {nresults} / {ntotal} examples")
return result

pbar.update()
with self._create_pbar(data) as pbar, logging_redirect_tqdm():
return list(map(function_with_progress, data))

def _execute_multi_thread(self, function, data):
results = [None] * len(data) # Pre-allocate results list to maintain order
job_cancelled = "cancelled"

@contextlib.contextmanager
def interrupt_handler_manager():
"""Sets the cancel_jobs event when a SIGINT is received, only in the main thread."""

# TODO: Is this check conducive to nested usage of ParallelExecutor?
if threading.current_thread() is threading.main_thread():
default_handler = signal.getsignal(signal.SIGINT)

def interrupt_handler(sig, frame):
self.cancel_jobs.set()
logger.warning("Received SIGINT. Cancelling execution.")
# Re-raise the signal to allow default behavior
default_handler(sig, frame)

signal.signal(signal.SIGINT, interrupt_handler)
try:
yield
finally:
signal.signal(signal.SIGINT, default_handler)
pbar = self._create_pbar(data)
total_score = 0
total_processed = 0

Collaborator comment: Somewhere here, we'd need to have something like:

    from dspy.dsp.utils.settings import thread_local_overrides
    parent_overrides = thread_local_overrides.overrides.copy()

and then pass parent_overrides along with the data, so that wrapped(item) uses the parent thread's overrides rather than the new child thread's overrides.
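
A hedged sketch of how the new multi-threaded path could restore that behavior; the worker name and structure are illustrative rather than taken from the PR, and the progress bar is omitted for brevity:

    def _execute_multi_thread(self, function, data):
        from dspy.dsp.utils.settings import thread_local_overrides

        # Capture the parent thread's overrides once, before any worker runs.
        parent_overrides = thread_local_overrides.overrides.copy()

        def worker(parent_overrides, item):
            # Each worker temporarily adopts a copy of the parent's overrides, so
            # child threads inherit dspy.settings from the caller instead of using
            # their own (empty) thread-local context.
            original = thread_local_overrides.overrides
            thread_local_overrides.overrides = parent_overrides.copy()
            try:
                return function(item)
            finally:
                thread_local_overrides.overrides = original

        with ThreadPoolExecutor(max_workers=self.num_threads) as pool:
            return list(pool.map(lambda item: worker(parent_overrides, item), data))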

def function_with_progress(item):
result = function(item)

nonlocal total_score, total_processed, pbar
total_processed += 1
if self.compare_results:
if result is not None:
total_score += result[-1]
self._update_pbar(pbar, total_score, total_processed)
else:
# If not in the main thread, skip setting signal handlers
yield
self._update_pbar(pbar, total_processed, len(data))

def cancellable_function(parent_overrides, index_item):
index, item = index_item
Collaborator comment: If I recall correctly, this was really important. It seems to have gotten lost in the (otherwise extremely neat) refactor. When launching multiple threads, we want each thread to inherit the parent thread's local overrides.

if self.cancel_jobs.is_set():
return index, job_cancelled

# Create an isolated context for each task by copying parent's overrides
from dspy.dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()
return result

with ThreadPoolExecutor(max_workers=self.num_threads) as pool:
try:
return index, function(item)
return list(pool.map(function_with_progress, data))
except Exception:
pool.shutdown(wait=False, cancel_futures=True)
raise
finally:
thread_local_overrides.overrides = original_overrides

with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
from dspy.dsp.utils.settings import thread_local_overrides
parent_overrides = thread_local_overrides.overrides.copy()

futures = {}
for pair in enumerate(data):
# Pass the parent thread's overrides to each thread
future = executor.submit(cancellable_function, parent_overrides, pair)
futures[future] = pair

pbar = tqdm.tqdm(
total=len(data),
dynamic_ncols=True,
disable=self.disable_progress_bar,
file=sys.stdout
)

for future in as_completed(futures):
index, result = future.result()

if result is job_cancelled:
continue

results[index] = result

if self.compare_results:
# Assumes score is the last element of the result tuple
self._update_progress(
pbar,
sum([r[-1] for r in results if r is not None]),
len([r for r in results if r is not None]),
)
else:
self._update_progress(
pbar,
len([r for r in results if r is not None]),
len(data),
)

pbar.close()

if self.cancel_jobs.is_set():
logger.warning("Execution was cancelled due to errors.")
raise Exception("Execution was cancelled due to errors.")

return results
pbar.close()
38 changes: 0 additions & 38 deletions tests/evaluate/test_evaluate.py
@@ -68,44 +68,6 @@ def test_multithread_evaluate_call():
assert score == 100.0


def test_multi_thread_evaluate_call_cancelled(monkeypatch):
# slow LM that sleeps for 1 second before returning the answer
class SlowLM(DummyLM):
def __call__(self, *args, **kwargs):
import time

time.sleep(1)
return super().__call__(*args, **kwargs)

dspy.settings.configure(lm=SlowLM({"What is 1+1?": {"answer": "2"}, "What is 2+2?": {"answer": "4"}}))

devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
program = Predict("question -> answer")
assert program(question="What is 1+1?").answer == "2"

# spawn a thread that will sleep for .1 seconds then send a KeyboardInterrupt
def sleep_then_interrupt():
import time

time.sleep(0.1)
import os

os.kill(os.getpid(), signal.SIGINT)

input_thread = threading.Thread(target=sleep_then_interrupt)
input_thread.start()

with pytest.raises(KeyboardInterrupt):
ev = Evaluate(
devset=devset,
metric=answer_exact_match,
display_progress=False,
num_threads=2,
)
score = ev(program)
assert score == 100.0
Collaborator (author) comment on lines -71 to -106: This effectively got moved to tests/utils/test_parallelizer.py.



def test_evaluate_call_bad():
dspy.settings.configure(lm=DummyLM({"What is 1+1?": {"answer": "0"}, "What is 2+2?": {"answer": "0"}}))
devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
83 changes: 83 additions & 0 deletions tests/utils/test_parallelizer.py
@@ -0,0 +1,83 @@
import time
import pytest

from dspy.utils.parallelizer import ParallelExecutor


def test_single_thread():
data = [0, 1, 2, 3, 4]
executor = ParallelExecutor(num_threads=1)
assert executor.execute(lambda x: x, data) == data


def test_failing_function():
"""
If any item raises an exception, ParallelExecutor should cancel execution
and raise an exception in the main thread without hanging.
"""
data = [0, 1, "boom", 3, 4]

def failing_func(x):
time.sleep(0.01)
if x == "boom":
raise ValueError("Simulated error")
return 42, x

executor = ParallelExecutor(
num_threads=2,
max_errors=1, # Immediately cancel after the first error
provide_traceback=True,
)

with pytest.raises(ValueError, match="Simulated error"):
_ = executor.execute(failing_func, data)


def test_max_errors():
"""
If the number of errors exceeds max_errors, the execution should be cancelled.
"""
data = [0, 1, "boom1", "boom2", "boom3", "boom4", 3, 4]

def failing_func(x):
time.sleep(0.01)
if isinstance(x, str) and x.startswith("boom"):
raise ValueError(f"Simulated error {x}")
return x

executor = ParallelExecutor(
num_threads=2,
max_errors=4,
provide_traceback=True,
)

with pytest.raises(ValueError, match=r"Simulated error boom[3|4]"):
_ = executor.execute(failing_func, data)


def test_sigint_interrupt():
"""
Demonstrate a synthetic Ctrl+C that cancels execution mid-stream.
In practice, you might just press Ctrl+C manually while running pytest.
"""
import signal

data = [0, 1, 2, 3, 4]

def interrupting_func(x):
if x == 2:
time.sleep(0.01)
# Simulate hitting Ctrl+C
signal.raise_signal(signal.SIGINT)
time.sleep(0.01)
return x

executor = ParallelExecutor(
num_threads=2,
max_errors=5,
provide_traceback=True,
)

# We expect a cancellation when 2 is processed
with pytest.raises(KeyboardInterrupt):
_ = executor.execute(interrupting_func, data)