Simplify Parallel Executor #2031

Base branch: fix-streaming
dspy/utils/parallelizer.py:

```diff
@@ -1,12 +1,11 @@
-import sys
-import tqdm
-import signal
-import logging
-import threading
-import traceback
-import contextlib
-
-from tqdm.contrib.logging import logging_redirect_tqdm
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
+import logging
+import sys
+import threading
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+
+import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
+
+
 logger = logging.getLogger(__name__)
```
```diff
@@ -20,6 +19,8 @@ def __init__(
         provide_traceback=False,
         compare_results=False,
     ):
+        assert num_threads > 0
+
         """Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""
         self.num_threads = num_threads
         self.disable_progress_bar = disable_progress_bar
@@ -28,180 +29,95 @@ def __init__(
         self.compare_results = compare_results

         self.error_count = 0
-        self.error_lock = threading.Lock()
-        self.cancel_jobs = threading.Event()
+        self._lock = threading.Lock()

     def execute(self, function, data):
         wrapped_function = self._wrap_function(function)
-        if self.num_threads == 1:
-            return self._execute_isolated_single_thread(wrapped_function, data)
-        else:
-            return self._execute_multi_thread(wrapped_function, data)
+        exec_type = "multi" if self.num_threads != 1 else "single"
+        executor = getattr(self, f"_execute_{exec_type}_thread")
+        return executor(wrapped_function, data)

     def _wrap_function(self, function):
-        # Wrap the function with error handling
+        from dspy.dsp.utils.settings import thread_local_overrides
+
         def wrapped(item):
-            if self.cancel_jobs.is_set():
-                return None
+            original_overrides = thread_local_overrides.overrides
+            thread_local_overrides.overrides = thread_local_overrides.overrides.copy()
             try:
                 return function(item)
             except Exception as e:
-                with self.error_lock:
-                    self.error_count += 1
-                    current_error_count = self.error_count
-                if current_error_count >= self.max_errors:
-                    self.cancel_jobs.set()
-                    raise e
                 if self.provide_traceback:
-                    logger.error(
-                        f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}"
-                    )
+                    logger.error(f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}")
                 else:
                     logger.error(
                         f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
                     )
+                with self._lock:
+                    self.error_count += 1
```
Review comment: I would probably move this above the logging.
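A sketch of that suggestion (illustrative only, not code from this PR): inside `wrapped`, the count-and-check would run before the log calls, so the re-raise on hitting `max_errors` happens first.

```python
def _wrap_function(self, function):
    from dspy.dsp.utils.settings import thread_local_overrides

    def wrapped(item):
        original_overrides = thread_local_overrides.overrides
        thread_local_overrides.overrides = thread_local_overrides.overrides.copy()
        try:
            return function(item)
        except Exception as e:
            # Suggested order: count the error (and possibly re-raise) first...
            with self._lock:
                self.error_count += 1
                if self.error_count >= self.max_errors:
                    raise e
            # ...then log it.
            if self.provide_traceback:
                logger.error(f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}")
            else:
                logger.error(f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace.")
            return None
        finally:
            thread_local_overrides.overrides = original_overrides

    return wrapped
```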
```diff
+                    if self.error_count >= self.max_errors:
+                        raise e
                 return None
+            finally:
+                thread_local_overrides.overrides = original_overrides

         return wrapped

-    def _execute_isolated_single_thread(self, function, data):
-        results = []
-        pbar = tqdm.tqdm(
-            total=len(data),
-            dynamic_ncols=True,
-            disable=self.disable_progress_bar,
-            file=sys.stdout
-        )
-
-        from dspy.dsp.utils.settings import thread_local_overrides
-        original_overrides = thread_local_overrides.overrides
-
-        for item in data:
-            with logging_redirect_tqdm():
-                if self.cancel_jobs.is_set():
-                    break
-
-                # Create an isolated context for each task by copying current overrides
-                # This way, even if an iteration modifies the overrides, it won't affect subsequent iterations
-                thread_local_overrides.overrides = original_overrides.copy()
-
-                try:
-                    result = function(item)
-                    results.append(result)
-                finally:
-                    thread_local_overrides.overrides = original_overrides
-
-                if self.compare_results:
-                    # Assumes score is the last element of the result tuple
-                    self._update_progress(
-                        pbar,
-                        sum([r[-1] for r in results if r is not None]),
-                        len([r for r in data if r is not None]),
-                    )
-                else:
-                    self._update_progress(pbar, len(results), len(data))
-
-        pbar.close()
-
-        if self.cancel_jobs.is_set():
-            logger.warning("Execution was cancelled due to errors.")
-            raise Exception("Execution was cancelled due to errors.")
-
-        return results
-
-    def _update_progress(self, pbar, nresults, ntotal):
-        if self.compare_results:
-            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
-            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
-        else:
-            pbar.set_description(f"Processed {nresults} / {ntotal} examples")
-
-        pbar.update()
+    def _create_pbar(self, data: list):
+        return tqdm.tqdm(total=len(data), dynamic_ncols=True, disable=self.disable_progress_bar, file=sys.stdout)
+
+    def _update_pbar(self, pbar: tqdm.tqdm, nresults, ntotal):
+        if self.compare_results:
+            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
+            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)", refresh=True)
+        else:
+            pbar.set_description(f"Processed {nresults} / {ntotal} examples", refresh=True)
+        pbar.update()
+
+    def _execute_single_thread(self, function, data):
+        total_score = 0
+        total_processed = 0
+
+        def function_with_progress(item):
+            result = function(item)
+
+            nonlocal total_score, total_processed, pbar
+            total_processed += 1
+            if self.compare_results:
+                if result is not None:
+                    total_score += result[-1]
+                self._update_pbar(pbar, total_score, total_processed)
+            else:
+                self._update_pbar(pbar, total_processed, len(data))
+
+            return result
+
+        with self._create_pbar(data) as pbar, logging_redirect_tqdm():
+            return list(map(function_with_progress, data))

     def _execute_multi_thread(self, function, data):
-        results = [None] * len(data)  # Pre-allocate results list to maintain order
-        job_cancelled = "cancelled"
-
-        @contextlib.contextmanager
-        def interrupt_handler_manager():
-            """Sets the cancel_jobs event when a SIGINT is received, only in the main thread."""
-            # TODO: Is this check conducive to nested usage of ParallelExecutor?
-            if threading.current_thread() is threading.main_thread():
-                default_handler = signal.getsignal(signal.SIGINT)
-
-                def interrupt_handler(sig, frame):
-                    self.cancel_jobs.set()
-                    logger.warning("Received SIGINT. Cancelling execution.")
-                    # Re-raise the signal to allow default behavior
-                    default_handler(sig, frame)
-
-                signal.signal(signal.SIGINT, interrupt_handler)
-                try:
-                    yield
-                finally:
-                    signal.signal(signal.SIGINT, default_handler)
-            else:
-                # If not in the main thread, skip setting signal handlers
-                yield
+        pbar = self._create_pbar(data)
+        total_score = 0
+        total_processed = 0
```
Review comment: Somewhere here, we'd need to have something like `from dspy.dsp.utils.settings import thread_local_overrides` and `parent_overrides = thread_local_overrides.overrides.copy()`, and then we should pass … (see the sketch after this diff).
```diff
+        def function_with_progress(item):
+            result = function(item)
+
+            nonlocal total_score, total_processed, pbar
+            total_processed += 1
+            if self.compare_results:
+                if result is not None:
+                    total_score += result[-1]
+                self._update_pbar(pbar, total_score, total_processed)
+            else:
+                self._update_pbar(pbar, total_processed, len(data))
+
+            return result

-        def cancellable_function(parent_overrides, index_item):
-            index, item = index_item
```

Review comment: If I recall correctly, this was really important. It seems to have gotten lost in the (otherwise extremely neat) refactor. When launching multiple threads, we want each thread to inherit the parent thread's local overrides.

```diff
-            if self.cancel_jobs.is_set():
-                return index, job_cancelled
-
-            # Create an isolated context for each task by copying parent's overrides
-            from dspy.dsp.utils.settings import thread_local_overrides
-            original_overrides = thread_local_overrides.overrides
-            thread_local_overrides.overrides = parent_overrides.copy()
-
-            try:
-                return index, function(item)
-            finally:
-                thread_local_overrides.overrides = original_overrides
-
-        with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
-            from dspy.dsp.utils.settings import thread_local_overrides
-            parent_overrides = thread_local_overrides.overrides.copy()
-
-            futures = {}
-            for pair in enumerate(data):
-                # Pass the parent thread's overrides to each thread
-                future = executor.submit(cancellable_function, parent_overrides, pair)
-                futures[future] = pair
-
-            pbar = tqdm.tqdm(
-                total=len(data),
-                dynamic_ncols=True,
-                disable=self.disable_progress_bar,
-                file=sys.stdout
-            )
-
-            for future in as_completed(futures):
-                index, result = future.result()
-
-                if result is job_cancelled:
-                    continue
-
-                results[index] = result
-
-                if self.compare_results:
-                    # Assumes score is the last element of the result tuple
-                    self._update_progress(
-                        pbar,
-                        sum([r[-1] for r in results if r is not None]),
-                        len([r for r in results if r is not None]),
-                    )
-                else:
-                    self._update_progress(
-                        pbar,
-                        len([r for r in results if r is not None]),
-                        len(data),
-                    )
-
-            pbar.close()
-
-            if self.cancel_jobs.is_set():
-                logger.warning("Execution was cancelled due to errors.")
-                raise Exception("Execution was cancelled due to errors.")
-
-            return results
+        with ThreadPoolExecutor(max_workers=self.num_threads) as pool:
+            try:
+                return list(pool.map(function_with_progress, data))
+            except Exception:
+                pool.shutdown(wait=False, cancel_futures=True)
+                raise
+
+        pbar.close()
```
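Putting those two review comments together, here is a minimal sketch of how `_execute_multi_thread` could snapshot the parent thread's overrides and install them in each worker. This is an illustration of the reviewers' suggestion using the `thread_local_overrides` API seen above, not code from the PR:

```python
from concurrent.futures import ThreadPoolExecutor


def _execute_multi_thread(self, function, data):
    from dspy.dsp.utils.settings import thread_local_overrides

    # Snapshot the calling (parent) thread's overrides once, up front.
    parent_overrides = thread_local_overrides.overrides.copy()

    def worker(item):
        # Each pool thread installs a copy of the parent's overrides,
        # runs the task, and restores whatever was there before.
        original = thread_local_overrides.overrides
        thread_local_overrides.overrides = parent_overrides.copy()
        try:
            return function(item)
        finally:
            thread_local_overrides.overrides = original

    with ThreadPoolExecutor(max_workers=self.num_threads) as pool:
        return list(pool.map(worker, data))
```

The key point is that `parent_overrides` is captured in the calling thread, before any pool thread runs; copying `thread_local_overrides.overrides` inside the worker would only see that worker's (empty) thread-local state.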
Evaluate tests:
```diff
@@ -68,44 +68,6 @@ def test_multithread_evaluate_call():
     assert score == 100.0


-def test_multi_thread_evaluate_call_cancelled(monkeypatch):
-    # slow LM that sleeps for 1 second before returning the answer
-    class SlowLM(DummyLM):
-        def __call__(self, *args, **kwargs):
-            import time
-
-            time.sleep(1)
-            return super().__call__(*args, **kwargs)
-
-    dspy.settings.configure(lm=SlowLM({"What is 1+1?": {"answer": "2"}, "What is 2+2?": {"answer": "4"}}))
-
-    devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
-    program = Predict("question -> answer")
-    assert program(question="What is 1+1?").answer == "2"
-
-    # spawn a thread that will sleep for .1 seconds then send a KeyboardInterrupt
-    def sleep_then_interrupt():
-        import time
-
-        time.sleep(0.1)
-        import os
-
-        os.kill(os.getpid(), signal.SIGINT)
-
-    input_thread = threading.Thread(target=sleep_then_interrupt)
-    input_thread.start()
-
-    with pytest.raises(KeyboardInterrupt):
-        ev = Evaluate(
-            devset=devset,
-            metric=answer_exact_match,
-            display_progress=False,
-            num_threads=2,
-        )
-        score = ev(program)
-        assert score == 100.0
```
Review comment on lines -71 to -106: This effectively got moved to …
```diff


 def test_evaluate_call_bad():
     dspy.settings.configure(lm=DummyLM({"What is 1+1?": {"answer": "0"}, "What is 2+2?": {"answer": "0"}}))
     devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
```
---|---|---|
@@ -0,0 +1,83 @@ | ||
import time | ||
import pytest | ||
|
||
from dspy.utils.parallelizer import ParallelExecutor | ||
|
||
|
||
def test_single_thread(): | ||
data = [0, 1, 2, 3, 4] | ||
executor = ParallelExecutor(num_threads=1) | ||
assert executor.execute(lambda x: x, data) == data | ||
|
||
|
||
def test_failing_function(): | ||
""" | ||
If any item raises an exception, ParallelExecutor should cancel execution | ||
and raise an exception in the main thread without hanging. | ||
""" | ||
data = [0, 1, "boom", 3, 4] | ||
|
||
def failing_func(x): | ||
time.sleep(0.01) | ||
if x == "boom": | ||
raise ValueError("Simulated error") | ||
return 42, x | ||
|
||
executor = ParallelExecutor( | ||
num_threads=2, | ||
max_errors=1, # Immediately cancel after the first error | ||
provide_traceback=True, | ||
) | ||
|
||
with pytest.raises(ValueError, match="Simulated error"): | ||
_ = executor.execute(failing_func, data) | ||
|
||
|
||
def test_max_errors(): | ||
""" | ||
If the number of errors exceeds max_errors, the execution should be cancelled. | ||
""" | ||
data = [0, 1, "boom1", "boom2", "boom3", "boom4", 3, 4] | ||
|
||
def failing_func(x): | ||
time.sleep(0.01) | ||
if isinstance(x, str) and x.startswith("boom"): | ||
raise ValueError(f"Simulated error {x}") | ||
return x | ||
|
||
executor = ParallelExecutor( | ||
num_threads=2, | ||
max_errors=4, | ||
provide_traceback=True, | ||
) | ||
|
||
with pytest.raises(ValueError, match=r"Simulated error boom[3|4]"): | ||
_ = executor.execute(failing_func, data) | ||
|
||
|
||
def test_sigint_interrupt(): | ||
""" | ||
Demonstrate a synthetic Ctrl+C that cancels execution mid-stream. | ||
In practice, you might just press Ctrl+C manually while running pytest. | ||
""" | ||
import signal | ||
|
||
data = [0, 1, 2, 3, 4] | ||
|
||
def interrupting_func(x): | ||
if x == 2: | ||
time.sleep(0.01) | ||
# Simulate hitting Ctrl+C | ||
signal.raise_signal(signal.SIGINT) | ||
time.sleep(0.01) | ||
return x | ||
|
||
executor = ParallelExecutor( | ||
num_threads=2, | ||
max_errors=5, | ||
provide_traceback=True, | ||
) | ||
|
||
# We expect a cancellation when 2 is processed | ||
with pytest.raises(KeyboardInterrupt): | ||
_ = executor.execute(interrupting_func, data) |
Review comment: Could we do `self._execute_single_thread if self.num_threads == 1 else self._execute_multi_thread`?
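A sketch of what that alternative dispatch would look like in `execute` (a hypothetical rewrite, not part of the diff):

```python
def execute(self, function, data):
    wrapped_function = self._wrap_function(function)
    # Conditional expression instead of building the method name for getattr
    executor = self._execute_single_thread if self.num_threads == 1 else self._execute_multi_thread
    return executor(wrapped_function, data)
```

This keeps the dispatch greppable: both method names appear literally in the source, so renaming tools and static analysis can still find them.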