Motivation.
This RFC describes the approach for supporting pipeline parallelism in the vLLM V1 architecture.
Pipeline parallelism was supported in V0 with the virtual-engine approach. In short, we create multiple virtual engines to match the number of pipeline stages, and each virtual engine has its own scheduler, block manager, and cache engine, so that they can schedule multiple batches simultaneously to the same executor with pipeline parallelism, saturating all pipeline stages to improve efficiency. However, the virtual-engine approach introduces the following drawbacks:
The lack of a centralized scheduler prevents global optimization from being applied.
It introduces complexity to the engine architecture and implementation.
In this RFC, we aim to support pipeline parallelism in the V1 LLMEngineCore, with the following properties:
Good performance: throughput and TTFT
The design should minimize pipeline bubbles
KV-cache efficiency
The design should minimize KV-cache fragmentation
The design should facilitate KV-cache block reuse across different requests
The design should be compatible with the current prefix caching mechanism
Architecture
The design should align well with the V1 architecture
Scheduling policy flexibility
The design should support existing policies (FCFS and priority) and future policies
Proposed Change.
The current V1 engine core runs a busy synchronous loop, and each iteration consists of 3 operations (sketched below):
schedule(): schedules a batch of requests to run, considering new requests, existing requests, and preempted requests.
execute(): accepts the scheduler output as the execution plan, executes the model, and returns the output.
update(): updates the scheduler state based on the finished batch execution output.
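A minimal sketch of this loop, assuming the scheduler and executor expose schedule()/execute()/update() as described above (illustrative only; the actual V1 engine core differs in detail):

```python
# Illustrative sketch of the current V1 engine-core busy loop; not the
# actual vLLM code, just the three operations described above.
def engine_core_busy_loop(scheduler, executor):
    while True:
        # schedule(): pick new, existing, and preempted requests to run.
        scheduler_output = scheduler.schedule()
        # execute(): run the model on the scheduled batch and get its output.
        model_output = executor.execute(scheduler_output)
        # update(): fold the finished batch's output back into scheduler state.
        scheduler.update(scheduler_output, model_output)
```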
In this section, we discuss the available options for adapting the V1 engine core architecture to achieve pipeline parallelism.
Option 1: Atomic engine step
Design sketch
Intuitively, it would be ideal to keep the current busy-loop mechanism in the engine core and isolate all changes required for pipeline parallelism within the executor, as shown in the figure above.
LLMEngineCore
The busy loop remains the same.
The model output no longer corresponds to the scheduler output (i.e., the microbatch in the figure). The model output in this iteration is the output of the microbatch we submitted to the executor PP_SIZE iterations ago.
RayExecutor
microbatch_queue: The queue size is the same as PP_SIZE, and we need to guarantee that the queue is always full. If there is not a sufficient number of microbatches (e.g., cold start or idle), then we need to push empty microbatches (i.e., None) to the queue.
execute() takes one new microbatch, then waits for and returns the execution result of the oldest microbatch (see the sketch below). Since we guarantee that the queue is always full, we can always get the result of the oldest microbatch immediately (but it may be None).
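A minimal sketch of such an executor, assuming a hypothetical non-blocking submit_fn that returns a future-like object with .result() (names are illustrative, not the actual Ray executor API):

```python
from collections import deque


class AtomicStepExecutor:
    """Illustrative Option 1 executor: keeps PP_SIZE microbatches in flight."""

    def __init__(self, pp_size, submit_fn):
        # submit_fn: assumed non-blocking submission to the pipeline workers,
        # returning a future-like object with .result().
        self.submit_fn = submit_fn
        # Pre-fill the queue with empty microbatches so it is always full.
        self.queue = deque([None] * pp_size)

    def execute(self, scheduler_output):
        # Pop the oldest entry; because the queue is kept full, this is the
        # microbatch submitted PP_SIZE iterations ago (possibly empty).
        oldest = self.queue.popleft()
        # Submit the new microbatch; None stands for an empty microbatch
        # pushed only to keep the pipeline (and the queue) full.
        new = self.submit_fn(scheduler_output) if scheduler_output is not None else None
        self.queue.append(new)
        # Wait for and return the oldest result (None if it was empty).
        return oldest.result() if oldest is not None else None
```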
Pros
The existing busy loop is (largely) unchanged, and all complexity is hidden at the executor level.
We still follow the “(not really) synchronous schedule” paradigm that submits one microbatch and receives the result of a (different) microbatch in the same synchronous function.
Cons
Degraded performance: The oldest (finished) microbatch won't be fetched unless a new microbatch is scheduled. Although we continuously push empty microbatches when no new requests come in, this may still introduce overhead.
Complexity in managing empty microbatches: To achieve the desired pipeline efficiency, we have to push empty microbatches (None) to the microbatch queue when there are no requests scheduled (e.g., cold start or system idle). Once we fail to maintain a full microbatch queue, the pipeline efficiency cannot be recovered unless we restart the engine.
Option 2 (Recommended): Two-stage engine loop
Since pipeline parallelism enables multiple inflight executions in a pipelined fashion, scheduling and execution become asynchronous by nature: before one microbatch finishes execution, the engine needs to schedule and submit another microbatch. Therefore, in this option, execute() is separated into two operations: the submission and the finish of a microbatch. Specifically, 4 operations are involved:
schedule(): the scheduler considers new requests and schedulable existing requests and schedules the microbatch
submit(): the engine submits the microbatch to the executor for execution
finish(): the executor finishes the execution of the microbatch
update(): the scheduler updates its state based on the finished microbatch's execution output
Design sketch
LLMEngineCore
The busy loop is changed to use an async loop.
The async loop is driven by the following events:
A new request comes in
An existing request becomes schedulable
The oldest microbatch finishes
The same code can run in a synchronous fashion when the microbatch_queue size is 1.
Ray Executor
A pipeline executor that simply executes whatever microbatches it receives (see the sketch below).
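A minimal asyncio-style sketch of the two-stage loop, assuming executor.submit() returns an awaitable and the scheduler exposes schedule()/update() as above (event triggers and error handling are elided; the names are assumptions, not the final API):

```python
import asyncio


async def engine_core_loop(scheduler, executor, queue_size):
    # In-flight microbatches; with queue_size == 1 this degenerates to the
    # current synchronous behavior (submit one microbatch, then wait for it).
    inflight: asyncio.Queue = asyncio.Queue(maxsize=queue_size)

    async def submit_stage():
        while True:
            # Driven by new requests, requests becoming schedulable, and a
            # free slot in the in-flight queue (triggers elided in this sketch).
            scheduler_output = scheduler.schedule()          # schedule()
            future = executor.submit(scheduler_output)       # submit(), non-blocking
            await inflight.put((scheduler_output, future))   # back-pressure at queue_size

    async def finish_stage():
        while True:
            scheduler_output, future = await inflight.get()
            model_output = await future                       # finish(): oldest microbatch
            scheduler.update(scheduler_output, model_output)  # update()

    await asyncio.gather(submit_stage(), finish_stage())
```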
Pros
Event-driven and performant, because the oldest microbatch can finish as soon as possible.
A stepping stone to extend to a fully async scheduler.
Cons
Changes the current synchronous busy loop.
Option 3: Virtual Engine
This is similar to the virtual-engine solution in vLLM V0.
Pros
Convenient to implement.
Good isolation.
Cons
Needs multiple schedulers, which are hard to manage and maintain.
Cannot reuse KV-cache from a different virtual engine; possible internal fragmentation.
Milestones for Option 2
We have the following milestones for achieving Option 2.
Introduce async loop in LLMEngineCore
[In parallel] Support multiple microbatches (disjoint requests)
Implement pipeline-parallel
Optimization: support scheduling the same prefill-stage request in multiple inflight microbatches
Note: a request in the decode stage can only be scheduled in one inflight microbatch until we figure out how to deal with speculative decoding and jump decoding; however, a request in the prefill stage can naturally be scheduled across multiple inflight microbatches, because prefill on a later PP stage does not depend on the complete finish of the previously scheduled tokens (see the sketch below).
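To make the prefill observation concrete, here is a hypothetical sketch of the per-request bookkeeping the scheduler could keep so that a later prefill chunk can be placed into a new microbatch while earlier chunks are still in flight (all names are illustrative assumptions, not existing vLLM APIs):

```python
from dataclasses import dataclass


@dataclass
class PrefillSchedulingState:
    """Hypothetical per-request state for scheduling prefill chunks across
    multiple inflight microbatches."""
    num_prompt_tokens: int        # total prompt length
    num_computed_tokens: int = 0  # prefill tokens whose execution has finished
    num_inflight_tokens: int = 0  # prefill tokens submitted but not yet finished

    def next_prefill_chunk(self, chunk_size: int) -> int:
        # A later prefill chunk only needs earlier chunks to have passed through
        # each pipeline stage before it arrives there, which the pipeline order
        # already guarantees, so we may schedule beyond the in-flight tokens
        # without waiting for them to fully finish.
        already_scheduled = self.num_computed_tokens + self.num_inflight_tokens
        remaining = self.num_prompt_tokens - already_scheduled
        return min(chunk_size, max(remaining, 0))
```

A decode-stage request, by contrast, would keep at most one step in flight, as noted above.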
Feedback Period.
No response
CC List.
@WoosukKwon @robertgshaw2-neuralmagic @tylertitsworth @youkaichao @simon-mo @comaniac @stephanie-wang
Any Other Things.
No response