Skip to content

Commit

Permalink
[Core] Support disaggregated prefill with Mooncake Transfer Engine (#10884)
Browse files: browse the repository at this point in the history

Signed-off-by: Shangming Cai <[email protected]>
  • Loading branch information
ShangmingCai authored Dec 15, 2024
1 parent 38e599d commit d263bd9
Show file tree
Hide file tree
Showing 4 changed files with 352 additions and 31 deletions.
7 changes: 4 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,13 +2171,14 @@ def from_cli(cls, cli_value: str) -> "KVTransferConfig":
return KVTransferConfig.model_validate_json(cli_value)

def model_post_init(self, __context: Any) -> None:
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if all([
self.kv_connector is not None,
self.kv_connector != "PyNcclConnector"
self.kv_connector is not None, self.kv_connector
not in supported_kv_connector
]):
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
f"Supported connectors are "
f"`PyNcclConnector`.")
f"{supported_kv_connector}.")

if self.kv_role is not None and self.kv_role not in [
"kv_producer", "kv_consumer", "kv_both"
Expand Down
3 changes: 2 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class KVConnectorFactory:
@staticmethod
def create_connector(rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
if config.kv_transfer_config.kv_connector == 'PyNcclConnector':
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if config.kv_transfer_config.kv_connector in supported_kv_connector:
from .simple_connector import SimpleConnector
return SimpleConnector(rank, local_rank, config)
else:
Expand Down
101 changes: 74 additions & 27 deletions vllm/distributed/kv_transfer/kv_connector/simple_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Simple KV Cache Connector for Distributed Machine Learning Inference
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe.
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
Expand All @@ -15,7 +16,6 @@
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

Expand All @@ -36,32 +36,66 @@ def __init__(

self.config = config.kv_transfer_config

logger.info("Initializing PyNcclConfig under kv_transfer_config %s",
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
PyNcclPipe)
logger.info(
"Initializing PyNcclConfig under kv_transfer_config %s",
self.config)
elif self.config.kv_connector == "MooncakeConnector":
# Check if MOONCAKE_CONFIG_PATH is set
import os
use_mooncake_distributed_pipe = os.getenv(
'MOONCAKE_CONFIG_PATH') is not None

if not use_mooncake_distributed_pipe:
raise ValueError(
"To use MooncakeConnector, you need to pass the ENV: "
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
else:
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501
MooncakePipe)
logger.info(
"Initializing MooncakeConfig under kv_transfer_config %s",
self.config)

self.lookup_buffer_size = self.config.kv_buffer_size

self.producer_buffer: Optional[SimpleBuffer] = None
self.consumer_buffer: Optional[SimpleBuffer] = None

self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]

# 2 pipes for every rank in the world
port_offset_base = 2 * rank

# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:

self.producer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
if self.config.kv_connector == "PyNcclConnector":
self.producer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
elif self.config.kv_connector == "MooncakeConnector":
self.producer_data_pipe = MooncakePipe(
local_rank=local_rank,
config=self.config,
)
# We only need to initialize MooncakePipe once
self.producer_signal_pipe = self.producer_data_pipe

self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
self.producer_data_pipe,
self.config.kv_buffer_size)
Expand All @@ -70,17 +104,25 @@ def __init__(

# the current vLLM instance is KV consumer, so it needs to connect
# its recv pipe to the send pipe of KV producer
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
elif self.config.kv_connector == "MooncakeConnector":
self.consumer_data_pipe = MooncakePipe(
local_rank=local_rank,
config=self.config,
)
self.consumer_signal_pipe = self.consumer_data_pipe

self.consumer_buffer = SimpleBuffer(
self.consumer_signal_pipe,
self.consumer_data_pipe,
Expand Down Expand Up @@ -260,6 +302,11 @@ def recv_kv_caches_and_hidden_states(

def close(self):
self.producer_data_pipe.close()
self.producer_signal_pipe.close()
self.consumer_data_pipe.close()
self.consumer_signal_pipe.close()
if self.config.kv_connector == "PyNcclConnector":
self.producer_signal_pipe.close()
self.consumer_signal_pipe.close()
elif self.config.kv_connector == "MooncakeConnector":
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass
Loading

0 comments on commit d263bd9

Please sign in to comment.