-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify Parallel Executor #2031
base: fix-streaming
Are you sure you want to change the base?
Changes from all commits
3d88999
198c9f3
addbf83
b659685
3a04f5f
ee220b4
3001098
cbd90ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,11 @@ | ||
import sys | ||
import tqdm | ||
import signal | ||
from concurrent.futures import ThreadPoolExecutor | ||
import logging | ||
import sys | ||
import threading | ||
import traceback | ||
import contextlib | ||
|
||
from tqdm.contrib.logging import logging_redirect_tqdm | ||
from concurrent.futures import ThreadPoolExecutor, as_completed | ||
from tqdm import tqdm | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -20,6 +19,8 @@ def __init__( | |
provide_traceback=False, | ||
compare_results=False, | ||
): | ||
assert num_threads > 0 | ||
|
||
"""Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1.""" | ||
self.num_threads = num_threads | ||
self.disable_progress_bar = disable_progress_bar | ||
|
@@ -28,180 +29,102 @@ def __init__( | |
self.compare_results = compare_results | ||
|
||
self.error_count = 0 | ||
self.error_lock = threading.Lock() | ||
self.cancel_jobs = threading.Event() | ||
self._lock = threading.Lock() | ||
|
||
def execute(self, function, data): | ||
wrapped_function = self._wrap_function(function) | ||
if self.num_threads == 1: | ||
return self._execute_isolated_single_thread(wrapped_function, data) | ||
else: | ||
return self._execute_multi_thread(wrapped_function, data) | ||
exec_type = "multi" if self.num_threads != 1 else "single" | ||
executor = getattr(self, f"_execute_{exec_type}_thread") | ||
return executor(wrapped_function, data) | ||
|
||
def _wrap_function(self, function): | ||
from dspy.dsp.utils.settings import thread_local_overrides | ||
|
||
# Wrap the function with error handling | ||
def wrapped(item): | ||
if self.cancel_jobs.is_set(): | ||
return None | ||
original_overrides = thread_local_overrides.overrides | ||
thread_local_overrides.overrides = thread_local_overrides.overrides.copy() | ||
try: | ||
return function(item) | ||
except Exception as e: | ||
with self.error_lock: | ||
self.error_count += 1 | ||
current_error_count = self.error_count | ||
if current_error_count >= self.max_errors: | ||
self.cancel_jobs.set() | ||
raise e | ||
if self.provide_traceback: | ||
logger.error( | ||
f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}" | ||
) | ||
logger.error(f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}") | ||
else: | ||
logger.error( | ||
f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace." | ||
) | ||
with self._lock: | ||
self.error_count += 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would probably move this above the logging |
||
if self.error_count >= self.max_errors: | ||
raise e | ||
return None | ||
finally: | ||
thread_local_overrides.overrides = original_overrides | ||
|
||
return wrapped | ||
|
||
def _execute_isolated_single_thread(self, function, data): | ||
results = [] | ||
pbar = tqdm.tqdm( | ||
total=len(data), | ||
dynamic_ncols=True, | ||
disable=self.disable_progress_bar, | ||
file=sys.stdout | ||
) | ||
def _create_pbar(self, data: list): | ||
return tqdm(total=len(data), dynamic_ncols=True, disable=self.disable_progress_bar, file=sys.stdout) | ||
|
||
from dspy.dsp.utils.settings import thread_local_overrides | ||
original_overrides = thread_local_overrides.overrides | ||
|
||
for item in data: | ||
with logging_redirect_tqdm(): | ||
if self.cancel_jobs.is_set(): | ||
break | ||
def _update_pbar(self, pbar: tqdm, nresults, ntotal): | ||
if self.compare_results: | ||
percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0 | ||
description = f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)" | ||
else: | ||
description = f"Processed {nresults} / {ntotal} examples" | ||
pbar.set_description(description, refresh=True) | ||
|
||
# Create an isolated context for each task by copying current overrides | ||
# This way, even if an iteration modifies the overrides, it won't affect subsequent iterations | ||
thread_local_overrides.overrides = original_overrides.copy() | ||
def _execute_single_thread(self, function, data): | ||
pbar = self._create_pbar(data) | ||
total_score = 0 | ||
total_processed = 0 | ||
|
||
try: | ||
result = function(item) | ||
results.append(result) | ||
finally: | ||
thread_local_overrides.overrides = original_overrides | ||
def function_with_progress(item): | ||
result = function(item) | ||
|
||
with self._lock: | ||
nonlocal total_score, total_processed, pbar | ||
total_processed += 1 | ||
if self.compare_results: | ||
# Assumes score is the last element of the result tuple | ||
self._update_progress( | ||
pbar, | ||
sum([r[-1] for r in results if r is not None]), | ||
len([r for r in data if r is not None]), | ||
) | ||
if result is not None: | ||
total_score += result[-1] | ||
self._update_pbar(pbar, total_score, total_processed) | ||
else: | ||
self._update_progress(pbar, len(results), len(data)) | ||
self._update_pbar(pbar, total_processed, len(data)) | ||
|
||
pbar.close() | ||
|
||
if self.cancel_jobs.is_set(): | ||
logger.warning("Execution was cancelled due to errors.") | ||
raise Exception("Execution was cancelled due to errors.") | ||
|
||
return results | ||
|
||
def _update_progress(self, pbar, nresults, ntotal): | ||
if self.compare_results: | ||
percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0 | ||
pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)") | ||
else: | ||
pbar.set_description(f"Processed {nresults} / {ntotal} examples") | ||
|
||
pbar.update() | ||
|
||
def _execute_multi_thread(self, function, data): | ||
results = [None] * len(data) # Pre-allocate results list to maintain order | ||
job_cancelled = "cancelled" | ||
|
||
@contextlib.contextmanager | ||
def interrupt_handler_manager(): | ||
"""Sets the cancel_jobs event when a SIGINT is received, only in the main thread.""" | ||
|
||
# TODO: Is this check conducive to nested usage of ParallelExecutor? | ||
if threading.current_thread() is threading.main_thread(): | ||
default_handler = signal.getsignal(signal.SIGINT) | ||
|
||
def interrupt_handler(sig, frame): | ||
self.cancel_jobs.set() | ||
logger.warning("Received SIGINT. Cancelling execution.") | ||
# Re-raise the signal to allow default behavior | ||
default_handler(sig, frame) | ||
|
||
signal.signal(signal.SIGINT, interrupt_handler) | ||
try: | ||
yield | ||
finally: | ||
signal.signal(signal.SIGINT, default_handler) | ||
else: | ||
# If not in the main thread, skip setting signal handlers | ||
yield | ||
|
||
def cancellable_function(parent_overrides, index_item): | ||
index, item = index_item | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If I recall correctly, this was really important. It seems to have gotten lost in the (otherwise extremely neat) refactor. When launching multiple threads, we want each thread to inherit the parent thread's local overrides. |
||
if self.cancel_jobs.is_set(): | ||
return index, job_cancelled | ||
|
||
# Create an isolated context for each task by copying parent's overrides | ||
from dspy.dsp.utils.settings import thread_local_overrides | ||
original_overrides = thread_local_overrides.overrides | ||
thread_local_overrides.overrides = parent_overrides.copy() | ||
return result | ||
|
||
with logging_redirect_tqdm(): | ||
try: | ||
return index, function(item) | ||
return list(map(function_with_progress, data)) | ||
finally: | ||
thread_local_overrides.overrides = original_overrides | ||
pbar.close() | ||
|
||
with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager(): | ||
from dspy.dsp.utils.settings import thread_local_overrides | ||
parent_overrides = thread_local_overrides.overrides.copy() | ||
|
||
futures = {} | ||
for pair in enumerate(data): | ||
# Pass the parent thread's overrides to each thread | ||
future = executor.submit(cancellable_function, parent_overrides, pair) | ||
futures[future] = pair | ||
|
||
pbar = tqdm.tqdm( | ||
total=len(data), | ||
dynamic_ncols=True, | ||
disable=self.disable_progress_bar, | ||
file=sys.stdout | ||
) | ||
|
||
for future in as_completed(futures): | ||
index, result = future.result() | ||
|
||
if result is job_cancelled: | ||
continue | ||
def _execute_multi_thread(self, function, data): | ||
pbar = self._create_pbar(data) | ||
total_score = 0 | ||
total_processed = 0 | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Somewhere here, we'd need to have something like: from dspy.dsp.utils.settings import thread_local_overrides | ||
parent_overrides = thread_local_overrides.overrides.copy() and then we should pass `parent_overrides` into each worker thread. |
||
results[index] = result | ||
def function_with_progress(item): | ||
result = function(item) | ||
|
||
with self._lock: | ||
nonlocal total_score, total_processed, pbar | ||
total_processed += 1 | ||
if self.compare_results: | ||
# Assumes score is the last element of the result tuple | ||
self._update_progress( | ||
pbar, | ||
sum([r[-1] for r in results if r is not None]), | ||
len([r for r in results if r is not None]), | ||
) | ||
if result is not None: | ||
total_score += result[-1] | ||
self._update_pbar(pbar, total_score, total_processed) | ||
else: | ||
self._update_progress( | ||
pbar, | ||
len([r for r in results if r is not None]), | ||
len(data), | ||
) | ||
|
||
pbar.close() | ||
self._update_pbar(pbar, total_processed, len(data)) | ||
|
||
if self.cancel_jobs.is_set(): | ||
logger.warning("Execution was cancelled due to errors.") | ||
raise Exception("Execution was cancelled due to errors.") | ||
return result | ||
|
||
return results | ||
with ThreadPoolExecutor(max_workers=self.num_threads) as pool: | ||
try: | ||
return list(pool.map(function_with_progress, data)) | ||
except Exception: | ||
pool.shutdown(wait=False, cancel_futures=True) | ||
raise | ||
finally: | ||
pbar.close() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we use
self._execute_single_thread if self.num_threads == 1 else self._execute_multi_thread
instead?