Skip to content

Commit

Permalink
[Misc] KV cache transfer connector registry (#11481)
Browse files Browse the repository at this point in the history
Signed-off-by: KuntaiDu <[email protected]>
  • Loading branch information
KuntaiDu authored Dec 29, 2024
1 parent dba4d9d commit faef77c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
8 changes: 0 additions & 8 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2559,14 +2559,6 @@ def from_cli(cls, cli_value: str) -> "KVTransferConfig":
return KVTransferConfig.model_validate_json(cli_value)

def model_post_init(self, __context: Any) -> None:
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if all([
self.kv_connector is not None, self.kv_connector
not in supported_kv_connector
]):
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
f"Supported connectors are "
f"{supported_kv_connector}.")

if self.kv_role is not None and self.kv_role not in [
"kv_producer", "kv_consumer", "kv_both"
Expand Down
48 changes: 38 additions & 10 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
import importlib
from typing import TYPE_CHECKING, Callable, Dict, Type

from .base import KVConnectorBase

Expand All @@ -7,14 +8,41 @@


class KVConnectorFactory:
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}

@staticmethod
def create_connector(rank: int, local_rank: int,
@classmethod
def register_connector(cls, name: str, module_path: str,
class_name: str) -> None:
"""Register a connector with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")

def loader() -> Type[KVConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)

cls._registry[name] = loader

@classmethod
def create_connector(cls, rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if config.kv_transfer_config.kv_connector in supported_kv_connector:
from .simple_connector import SimpleConnector
return SimpleConnector(rank, local_rank, config)
else:
raise ValueError(f"Unsupported connector type: "
f"{config.kv_connector}")
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")

connector_cls = cls._registry[connector_name]()
return connector_cls(rank, local_rank, config)


# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
"PyNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")

KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")

0 comments on commit faef77c

Please sign in to comment.