-
-
Notifications
You must be signed in to change notification settings - Fork 5.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[V1] Improve TP>1 Error Handling + Stack Trace #11721
Changes from 34 commits
2d857cd
9e70c5f
f34875c
7a777d9
dfc9dee
c72b45a
4e2dc00
0b4b6af
4c445af
567b424
7d04b98
62e1022
0b0ca08
729938a
0259241
58e4b36
cacf6b0
ccc747d
ddc2a97
af0d529
17e152b
37859d7
c29f329
1c4b92a
eb9b00b
1da99a8
ca7b92d
2743166
8e257c1
b7c50dc
dcfd3b8
6e0e0d4
55a6195
aa6954f
1d15ae0
0347baa
20b8fa2
32840f2
884879a
bb86a03
405bcc1
25e0fea
efd6270
a5a306e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
from multiprocessing.process import BaseProcess | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import psutil | ||
import zmq | ||
|
||
from vllm.config import VllmConfig | ||
|
@@ -38,6 +39,19 @@ def __init__(self, vllm_config: VllmConfig) -> None: | |
# and ensure workers will be terminated. | ||
self._finalizer = weakref.finalize(self, self.shutdown) | ||
|
||
# The child processes will send SIGQUIT when unrecoverable | ||
# errors happen. | ||
def sigquit_handler(signum, frame): | ||
logger.fatal( | ||
"MulitprocExecutor got SIGQUIT from worker processes, shutting " | ||
"down. See stack trace above for root cause issue.") | ||
# Propagate error up to parent process. | ||
parent_process = psutil.Process().parent() | ||
parent_process.send_signal(signal.SIGQUIT) | ||
self.shutdown() | ||
|
||
signal.signal(signal.SIGQUIT, sigquit_handler) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW: why use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mostly because this was inspired by SGL and they use |
||
|
||
self.vllm_config = vllm_config | ||
self.parallel_config = vllm_config.parallel_config | ||
|
||
|
@@ -335,8 +349,11 @@ def signal_handler(signum, frame): | |
except SystemExit: | ||
logger.debug("Worker interrupted.") | ||
|
||
except BaseException as e: | ||
logger.exception(e) | ||
except Exception: | ||
# worker_busy_loop sends exceptions exceptons to Executor | ||
robertgshaw2-redhat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# for shutdown, but if there is an error in startup or an | ||
# error with IPC itself, we need to alert the parent. | ||
psutil.Process().parent().send_signal(signal.SIGQUIT) | ||
raise | ||
|
||
finally: | ||
|
@@ -372,14 +389,16 @@ class ResponseStatus(Enum): | |
|
||
def worker_busy_loop(self): | ||
"""Main busy loop for Multiprocessing Workers""" | ||
|
||
while True: | ||
method, args, kwargs = self.rpc_broadcast_mq.dequeue() | ||
|
||
try: | ||
output = getattr(self.worker, method)(*args, **kwargs) | ||
except BaseException as e: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
except Exception as e: | ||
self.worker_response_mq.enqueue( | ||
(WorkerProc.ResponseStatus.FAILURE, e)) | ||
logger.exception("WorkerProc hit an exception: %s", exc_info=e) | ||
continue | ||
|
||
self.worker_response_mq.enqueue( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,8 @@ def __init__( | |
distributed_init_method: str, | ||
): | ||
|
||
self.i = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NOTE FOR REVIEWER: this is just a simple POC to show an example. Will remove this before landing. |
||
|
||
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) | ||
self.vllm_config = vllm_config | ||
self.model_config = vllm_config.model_config | ||
|
@@ -201,6 +203,10 @@ def execute_model( | |
self, | ||
scheduler_output: "SchedulerOutput", | ||
) -> ModelRunnerOutput: | ||
if self.rank == 0 and self.i == 10: | ||
raise ValueError("ERROR FROM HERE :)") | ||
self.i += 1 | ||
robertgshaw2-redhat marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
output = self.model_runner.execute_model(scheduler_output) | ||
return output if self.rank == 0 else None | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NOTE: moved to
CoreClient
so that it can be shared acrossAsyncLLM
andLLMEngine