diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b563c96343f92..b28296d603642 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -117,6 +117,16 @@ steps: - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests +- label: Platform Plugin Test + working_dir: "/vllm-workspace/tests" + fast_check: true + mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + commands: + - pip install -e ./plugins/vllm_add_dummy_platform + - pytest -v -s platform/test_platform_plugin.py + - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 diff --git a/docs/source/design/plugin_system.md b/docs/source/design/plugin_system.md index 79aff757518f2..6a2b5258e07d1 100644 --- a/docs/source/design/plugin_system.md +++ b/docs/source/design/plugin_system.md @@ -43,7 +43,7 @@ Every plugin has three parts: ## What Can Plugins Do? -Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM. +Currently, the primary use case for plugins is to register custom, out-of-the-tree models or platforms into vLLM. This is done by calling `ModelRegistry.register_model` or `PlatformRegistry.register_platform` to register the model or platform. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM. ## Guidelines for Writing Plugins diff --git a/tests/platform/test_platform_plugin.py b/tests/platform/test_platform_plugin.py new file mode 100644 index 0000000000000..a49309e7c08e7 --- /dev/null +++ b/tests/platform/test_platform_plugin.py @@ -0,0 +1,13 @@ +import pytest + +from vllm_add_dummy_platform.my_platform import DummyPlatform + +from vllm.platforms import PlatformRegistry, current_platform + + +def test_current_platform_register(): + # make sure the platform is registered + assert PlatformRegistry.current_platform == "my_platform" + # make sure the platform is loaded + assert current_platform.device_name == "dummy" + assert type(current_platform.platform) == DummyPlatform diff --git a/tests/plugins/vllm_add_dummy_platform/setup.py b/tests/plugins/vllm_add_dummy_platform/setup.py new file mode 100644 index 0000000000000..0b73d173040ab --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup + +setup(name='vllm_add_dummy_platform', + version='0.1', + packages=['vllm_add_dummy_platform'], + entry_points={ + 'vllm.general_plugins': + ["register_dummy_model = vllm_add_dummy_platform:register"] + }) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py new file mode 100644 index 0000000000000..8435c365446d7 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py @@ -0,0 +1,9 @@ +from vllm import PlatformRegistry + + +def register(): + # Register the dummy platform + PlatformRegistry.register_platform( + "my_platform", "vllm_add_dummy_platform.my_platform.DummyPlatform") + # Set the current platform to the dummy platform + PlatformRegistry.set_current_platform("my_platform") diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_attention.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_attention.py new file mode 100644 index 0000000000000..8c0df08fa29b8 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_attention.py @@ -0,0 +1,13 @@ +class DummyAttentionImpl: + + def forward(self): + pass + + +class DummyAttentionBackend: + + def __init__(self): + pass + + def get_impl_cls(self): + return DummyAttentionImpl diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_model_runner.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_model_runner.py new file mode 100644 index 0000000000000..1d9060b6b7e68 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_model_runner.py @@ -0,0 +1,7 @@ +from my_attention import DummyAttentionBackend + + +class DummyModelRunner: + + def __init__(self): + self.attn_backend = DummyAttentionBackend() diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py new file mode 100644 index 0000000000000..86a3d9b6eeab6 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_platform.py @@ -0,0 +1,20 @@ +from vllm.config import VllmConfig +from vllm.platforms import Platform, PlatformEnum + + +class DummyPlatform(Platform): + _enum = PlatformEnum.UNSPECIFIED + device_name = "dummy" + + def __init__(self): + super().__init__() + + @classmethod + def get_device_name(cls) -> str: + return "dummy" + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + parallel_config.worker_cls = \ + "vllm_add_dummy_platform.my_worker.DummyWorker" diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_worker.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_worker.py new file mode 100644 index 0000000000000..a144df2438b20 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/my_worker.py @@ -0,0 +1,14 @@ +from typing import List + +from my_model_runner import DummyModelRunner + + +class DummyCacheEngine: + pass + + +class DummyWorker: + + def __init__(self): + self.cache_engine = List[DummyCacheEngine] + self.model_runner = DummyModelRunner() diff --git a/vllm/__init__.py b/vllm/__init__.py index 45252b93e3d54..fa91a8947b95e 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -12,16 +12,23 @@ EmbeddingRequestOutput, PoolingOutput, PoolingRequestOutput, RequestOutput, ScoringOutput, ScoringRequestOutput) +from vllm.platforms.registry import PlatformRegistry +from vllm.plugins import load_general_plugins from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from .version import __version__, __version_tuple__ +# Load general plugins first when the module is imported to make sure that all +# necessary global variables are set. Such as the `current_platform`. +load_general_plugins() + __all__ = [ "__version__", "__version_tuple__", "LLM", "ModelRegistry", + "PlatformRegistry", "PromptType", "TextPrompt", "TokensPrompt", diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 05d997279893b..54ba60229325e 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -263,15 +263,6 @@ def unified_attention_fake( return torch.empty_like(query).contiguous() -direct_register_custom_op( - op_name="unified_attention", - op_func=unified_attention, - mutates_args=["kv_cache"], - fake_impl=unified_attention_fake, - dispatch_key=current_platform.dispatch_key, -) - - def unified_attention_with_output( query: torch.Tensor, key: torch.Tensor, @@ -307,10 +298,19 @@ def unified_attention_with_output_fake( return -direct_register_custom_op( - op_name="unified_attention_with_output", - op_func=unified_attention_with_output, - mutates_args=["kv_cache", "output"], - fake_impl=unified_attention_with_output_fake, - dispatch_key=current_platform.dispatch_key, -) +def register_custom_ops(): + """Register custom ops for attention.""" + direct_register_custom_op( + op_name="unified_attention", + op_func=unified_attention, + mutates_args=["kv_cache"], + fake_impl=unified_attention_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( + op_name="unified_attention_with_output", + op_func=unified_attention_with_output, + mutates_args=["kv_cache", "output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, + ) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 419237c252ffd..23b7e40e4f37c 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,123 +1,50 @@ +from typing import Any + from .interface import _Backend # noqa: F401 from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform - -current_platform: Platform - -# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because -# they only indicate the build configuration, not the runtime environment. -# For example, people can install a cuda build of pytorch but run on tpu. - -is_tpu = False -try: - # While it's technically possible to install libtpu on a non-TPU machine, - # this is a very uncommon scenario. Therefore, we assume that libtpu is - # installed if and only if the machine has TPUs. - import libtpu # noqa: F401 - is_tpu = True -except Exception: - pass - -is_cuda = False - -try: - import pynvml - pynvml.nvmlInit() - try: - if pynvml.nvmlDeviceGetCount() > 0: - is_cuda = True - finally: - pynvml.nvmlShutdown() -except Exception: - # CUDA is supported on Jetson, but NVML may not be. - import os - - def cuda_is_jetson() -> bool: - return os.path.isfile("/etc/nv_tegra_release") \ - or os.path.exists("/sys/class/tegra-firmware") - - if cuda_is_jetson(): - is_cuda = True - -is_rocm = False - -try: - import amdsmi - amdsmi.amdsmi_init() - try: - if len(amdsmi.amdsmi_get_processor_handles()) > 0: - is_rocm = True - finally: - amdsmi.amdsmi_shut_down() -except Exception: - pass - -is_hpu = False -try: - from importlib import util - is_hpu = util.find_spec('habana_frameworks') is not None -except Exception: - pass - -is_xpu = False - -try: - # installed IPEX if the machine has XPUs. - import intel_extension_for_pytorch # noqa: F401 - import oneccl_bindings_for_pytorch # noqa: F401 - import torch - if hasattr(torch, 'xpu') and torch.xpu.is_available(): - is_xpu = True -except Exception: - pass - -is_cpu = False -try: - from importlib.metadata import version - is_cpu = "cpu" in version("vllm") -except Exception: - pass - -is_neuron = False -try: - import transformers_neuronx # noqa: F401 - is_neuron = True -except ImportError: - pass - -is_openvino = False -try: - from importlib.metadata import version - is_openvino = "openvino" in version("vllm") -except Exception: - pass - -if is_tpu: - # people might install pytorch built with cuda but run on tpu - # so we need to check tpu first - from .tpu import TpuPlatform - current_platform = TpuPlatform() -elif is_cuda: - from .cuda import CudaPlatform - current_platform = CudaPlatform() -elif is_rocm: - from .rocm import RocmPlatform - current_platform = RocmPlatform() -elif is_hpu: - from .hpu import HpuPlatform - current_platform = HpuPlatform() -elif is_xpu: - from .xpu import XPUPlatform - current_platform = XPUPlatform() -elif is_cpu: - from .cpu import CpuPlatform - current_platform = CpuPlatform() -elif is_neuron: - from .neuron import NeuronPlatform - current_platform = NeuronPlatform() -elif is_openvino: - from .openvino import OpenVinoPlatform - current_platform = OpenVinoPlatform() -else: - current_platform = UnspecifiedPlatform() +from .registry import PlatformRegistry, detect_current_platform + + +class CurrentPlatform(Platform): + """A wrapper that provides an interface to the current platform. + + `current_platform` is imported to many modules once vLLM is imported. + Updating `current_platform` value directly will not work in those modules. + So it needs the wrapper here to provide a dynamic platform loading + mechanism. + + This class can make sure that the `current_platform` is always up-to-date. + """ + + def __init__(self): + self.platform: Platform = UnspecifiedPlatform() + + def initialize_current_platform(self): + """Initialize the current platform. This function is called when loading + the vllm plugin.""" + # Get the current platform from the registry first. If the current + # platform is not set, try to detect the current platform. + if PlatformRegistry.current_platform is not None: + self.platform = PlatformRegistry.get_current_platform_cls()() + else: + self.platform = detect_current_platform() + + # Register custom ops for the current platform. + from vllm.attention.layer import register_custom_ops + register_custom_ops() + + def __getattr__(self, name: str) -> Any: + """Get the attribute. If the attribute is not found, go pass to the + current platform.""" + if name == 'platform': + return self.platform + if name == 'initialize_current_platform': + return self.initialize_current_platform + # Go pass to the current platform. + return getattr(self.platform, name) + + +# The global variable for other modules to use. +current_platform = CurrentPlatform() __all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum'] diff --git a/vllm/platforms/registry.py b/vllm/platforms/registry.py new file mode 100644 index 0000000000000..946e960e651ab --- /dev/null +++ b/vllm/platforms/registry.py @@ -0,0 +1,171 @@ +from dataclasses import dataclass, field +from typing import Callable, Dict, Optional + +from vllm import utils +from vllm.platforms import Platform + +from .interface import UnspecifiedPlatform + +# The list of supported in-tree platforms. Update this list when adding/removing +# platforms. +_VLLM_PLATFORMS = { + "cpu": "vllm.platforms.cpu:CpuPlatform", + "cuda": "vllm.platforms.cuda:CudaPlatform", + "hpu": "vllm.platforms.hpu:HpuPlatform", + "neuron": "vllm.platforms.neuron:NeuronPlatform", + "openvino": "vllm.platforms.openvino:OpenVinoPlatform", + "rocm": "vllm.platforms.rocm:RocmPlatform", + "tpu": "vllm.platforms.tpu:TpuPlatform", + "xpu": "vllm.platforms.xpu:XPUPlatform", +} + + +@dataclass +class _PlatformRegistry: + # The mapping from device name to platform class string. + platforms: Dict[str, str] = field(default_factory=dict) + # The current platform name. + current_platform: Optional[str] = None + + def _load_platform_cls(self, device_name: str) -> Callable: + """Load a platform object by device name.""" + if device_name not in self.platforms: + raise ValueError( + f"Platform {device_name} not registered. " + f"Available platforms: {list(self.platforms.keys())}") + platform_cls_str = self.platforms[device_name] + return utils.resolve_obj_by_qualname(platform_cls_str) + + def register_platform(self, device_name: str, platform: str): + """Register a platform by device name. This function is called by the + platform plugin.""" + if device_name in self.platforms: + raise ValueError(f"Platform {device_name} already registered.") + self.platforms[device_name] = platform + + def set_current_platform(self, device_name: str): + """Set the current platform by device name.""" + if device_name not in self.platforms: + raise ValueError( + f"Platform {device_name} not registered. " + f"Available platforms: {list(self.platforms.keys())}") + self.current_platform = device_name + + def get_current_platform_cls(self) -> Callable: + """Get the current platform object.""" + if self.current_platform is None: + raise ValueError("No current platform set.") + return self._load_platform_cls(self.current_platform) + + +PlatformRegistry = _PlatformRegistry({ + device_name: platform + for device_name, platform in _VLLM_PLATFORMS.items() +}) + + +def detect_current_platform() -> Platform: + """Detect the current platform by checking the installed packages.""" + CurrentPlatform: Optional[type[Platform]] = None + # NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because + # they only indicate the build configuration, not the runtime environment. + # For example, people can install a cuda build of pytorch but run on tpu. + + # Load TPU Platform + try: + # While it's technically possible to install libtpu on a non-TPU + # machine, this is a very uncommon scenario. Therefore, we assume that + # libtpu is installed if and only if the machine has TPUs. + import libtpu # noqa: F401 + + from .tpu import TpuPlatform as CurrentPlatform + except Exception: + pass + + # Load CUDA Platform + if not CurrentPlatform: + try: + import pynvml + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + from .cuda import CudaPlatform as CurrentPlatform + finally: + pynvml.nvmlShutdown() + except Exception: + # CUDA is supported on Jetson, but NVML may not be. + import os + + def cuda_is_jetson() -> bool: + return os.path.isfile("/etc/nv_tegra_release") \ + or os.path.exists("/sys/class/tegra-firmware") + + if cuda_is_jetson(): + from .cuda import CudaPlatform as CurrentPlatform + + # Load ROCm Platform + if not CurrentPlatform: + try: + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + from .rocm import RocmPlatform as CurrentPlatform + finally: + amdsmi.amdsmi_shut_down() + except Exception: + pass + + # Load HPU Platform + if not CurrentPlatform: + try: + from importlib import util + assert util.find_spec('habana_frameworks') is not None + from .hpu import HpuPlatform as CurrentPlatform + except Exception: + pass + + # Load XPU Platform + if not CurrentPlatform: + try: + # installed IPEX if the machine has XPUs. + import intel_extension_for_pytorch # noqa: F401 + import oneccl_bindings_for_pytorch # noqa: F401 + import torch + if hasattr(torch, 'xpu') and torch.xpu.is_available(): + from .xpu import XPUPlatform as CurrentPlatform + except Exception: + pass + + # Load CPU Platform + if not CurrentPlatform: + try: + from importlib.metadata import version + assert "cpu" in version("vllm") + from .cpu import CpuPlatform as CurrentPlatform + except Exception: + pass + + # Load Neuron Platform + if not CurrentPlatform: + try: + import transformers_neuronx # noqa: F401 + + from .neuron import NeuronPlatform as CurrentPlatform + except ImportError: + pass + + # Load OpenVINO Platform + if not CurrentPlatform: + try: + from importlib.metadata import version + assert "openvino" in version("vllm") + from .openvino import OpenVinoPlatform as CurrentPlatform + except Exception: + pass + + if CurrentPlatform: + PlatformRegistry.set_current_platform(CurrentPlatform.device_name) + return CurrentPlatform() + + return UnspecifiedPlatform() diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 17f604ea0e202..051306a3812dc 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -17,6 +17,44 @@ def load_general_plugins(): processes. They should be designed in a way that they can be loaded multiple times without causing issues. """ + global plugins_loaded + if not plugins_loaded: + import sys + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + allowed_plugins = envs.VLLM_PLUGINS + + discovered_plugins = entry_points(group='vllm.general_plugins') + if len(discovered_plugins) == 0: + logger.debug("No plugins found.") + else: + logger.info("Available plugins:") + for plugin in discovered_plugins: + logger.info("name=%s, value=%s, group=%s", plugin.name, + plugin.value, plugin.group) + if allowed_plugins is None: + logger.info("all available plugins will be loaded.") + logger.info("set environment variable VLLM_PLUGINS to control" + " which plugins to load.") + else: + logger.info("plugins to load: %s", allowed_plugins) + for plugin in discovered_plugins: + if allowed_plugins is None or plugin.name in allowed_plugins: + try: + func = plugin.load() + func() + logger.info("plugin %s loaded.", plugin.name) + except Exception: + logger.exception("Failed to load plugin %s", + plugin.name) + # initialize current platform should be called after all plugins are + # loaded. + current_platform.initialize_current_platform() + + plugins_loaded = True # all processes created by vllm will load plugins, # and here we can inject some common environment variables @@ -42,38 +80,3 @@ def load_general_plugins(): # requires enabling lazy collectives # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' - - global plugins_loaded - if plugins_loaded: - return - plugins_loaded = True - import sys - if sys.version_info < (3, 10): - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points - - allowed_plugins = envs.VLLM_PLUGINS - - discovered_plugins = entry_points(group='vllm.general_plugins') - if len(discovered_plugins) == 0: - logger.debug("No plugins found.") - return - logger.info("Available plugins:") - for plugin in discovered_plugins: - logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value, - plugin.group) - if allowed_plugins is None: - logger.info("all available plugins will be loaded.") - logger.info("set environment variable VLLM_PLUGINS to control" - " which plugins to load.") - else: - logger.info("plugins to load: %s", allowed_plugins) - for plugin in discovered_plugins: - if allowed_plugins is None or plugin.name in allowed_plugins: - try: - func = plugin.load() - func() - logger.info("plugin %s loaded.", plugin.name) - except Exception: - logger.exception("Failed to load plugin %s", plugin.name)