diff --git a/vllm/config.py b/vllm/config.py index 643698f8bbec3..5e11c02541861 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2560,14 +2560,6 @@ def from_cli(cls, cli_value: str) -> "KVTransferConfig": return KVTransferConfig.model_validate_json(cli_value) def model_post_init(self, __context: Any) -> None: - supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"] - if all([ - self.kv_connector is not None, self.kv_connector - not in supported_kv_connector - ]): - raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " - f"Supported connectors are " - f"{supported_kv_connector}.") if self.kv_role is not None and self.kv_role not in [ "kv_producer", "kv_consumer", "kv_both" diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 3e2bb436d24b5..6372dab726086 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING +import importlib +from typing import TYPE_CHECKING, Callable, Dict, Type from .base import KVConnectorBase @@ -7,14 +8,41 @@ class KVConnectorFactory: + _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} - @staticmethod - def create_connector(rank: int, local_rank: int, + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> Type[KVConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector(cls, rank: int, local_rank: int, config: "VllmConfig") -> KVConnectorBase: - supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"] - if config.kv_transfer_config.kv_connector in supported_kv_connector: - from .simple_connector import SimpleConnector - return SimpleConnector(rank, local_rank, config) - else: - raise ValueError(f"Unsupported connector type: " - f"{config.kv_connector}") + connector_name = config.kv_transfer_config.kv_connector + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + return connector_cls(rank, local_rank, config) + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. +KVConnectorFactory.register_connector( + "PyNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector")