Skip to content

Commit

Permalink
Add platform pluggable
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Dec 24, 2024
1 parent a491d6f commit 62b074a
Show file tree
Hide file tree
Showing 14 changed files with 374 additions and 171 deletions.
10 changes: 10 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ steps:
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

- label: Platform Plugin Test
working_dir: "/vllm-workspace/tests"
fast_check: true
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s platform/test_platform_plugin.py

- label: Distributed Tests (4 GPUs) # 10min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
Expand Down
2 changes: 1 addition & 1 deletion docs/source/design/plugin_system.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Every plugin has three parts:

## What Can Plugins Do?

Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
Currently, the primary use case for plugins is to register custom, out-of-the-tree models or platforms into vLLM. This is done by calling `ModelRegistry.register_model` or `PlatformRegistry.register_platform` to register the model or platform. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.

## Guidelines for Writing Plugins

Expand Down
13 changes: 13 additions & 0 deletions tests/platform/test_platform_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

Check failure on line 1 in tests/platform/test_platform_plugin.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (F401)

tests/platform/test_platform_plugin.py:1:8: F401 `pytest` imported but unused

from vllm_add_dummy_platform.my_platform import DummyPlatform

from vllm.platforms import PlatformRegistry, current_platform


def test_current_platform_register():
# make sure the platform is registered
assert PlatformRegistry.current_platform == "my_platform"
# make sure the platform is loaded
assert current_platform.device_name == "dummy"
assert type(current_platform.platform) == DummyPlatform

Check failure on line 13 in tests/platform/test_platform_plugin.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E721)

tests/platform/test_platform_plugin.py:13:12: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from setuptools import setup

setup(name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
'vllm.general_plugins':
["register_dummy_model = vllm_add_dummy_platform:register"]
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from vllm import PlatformRegistry


def register():
# Register the dummy platform
PlatformRegistry.register_platform(
"my_platform", "vllm_add_dummy_platform.my_platform.DummyPlatform")
# Set the current platform to the dummy platform
PlatformRegistry.set_current_platform("my_platform")
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class DummyAttentionImpl:

def forward(self):
pass


class DummyAttentionBackend:

def __init__(self):
pass

def get_impl_cls(self):
return DummyAttentionImpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from my_attention import DummyAttentionBackend


class DummyModelRunner:

def __init__(self):
self.attn_backend = DummyAttentionBackend()
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from vllm.config import VllmConfig
from vllm.platforms import Platform, PlatformEnum


class DummyPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
device_name = "dummy"

def __init__(self):
super().__init__()

@classmethod
def get_device_name(cls) -> str:
return "dummy"

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = \
"vllm_add_dummy_platform.my_worker.DummyWorker"
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import List

from my_model_runner import DummyModelRunner


class DummyCacheEngine:
pass


class DummyWorker:

def __init__(self):
self.cache_engine = List[DummyCacheEngine]
self.model_runner = DummyModelRunner()
7 changes: 7 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.platforms.registry import PlatformRegistry
from vllm.plugins import load_general_plugins
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

from .version import __version__, __version_tuple__

# Load general plugins first when the module is imported to make sure that all
# necessary global variables are set. Such as the `current_platform`.
load_general_plugins()

__all__ = [
"__version__",
"__version_tuple__",
"LLM",
"ModelRegistry",
"PlatformRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
Expand Down
32 changes: 16 additions & 16 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,6 @@ def unified_attention_fake(
return torch.empty_like(query).contiguous()


direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -307,10 +298,19 @@ def unified_attention_with_output_fake(
return


direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
def register_custom_ops():
"""Register custom ops for attention."""
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
165 changes: 46 additions & 119 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,50 @@
from typing import Any

from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform

current_platform: Platform

# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

is_tpu = False
try:
# While it's technically possible to install libtpu on a non-TPU machine,
# this is a very uncommon scenario. Therefore, we assume that libtpu is
# installed if and only if the machine has TPUs.
import libtpu # noqa: F401
is_tpu = True
except Exception:
pass

is_cuda = False

try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
# CUDA is supported on Jetson, but NVML may not be.
import os

def cuda_is_jetson() -> bool:
return os.path.isfile("/etc/nv_tegra_release") \
or os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
is_cuda = True

is_rocm = False

try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass

is_hpu = False
try:
from importlib import util
is_hpu = util.find_spec('habana_frameworks') is not None
except Exception:
pass

is_xpu = False

try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
except Exception:
pass

is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
except Exception:
pass

is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass

is_openvino = False
try:
from importlib.metadata import version
is_openvino = "openvino" in version("vllm")
except Exception:
pass

if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
else:
current_platform = UnspecifiedPlatform()
from .registry import PlatformRegistry, detect_current_platform


class CurrentPlatform(Platform):
"""A wrapper that provides an interface to the current platform.
`current_platform` is imported to many modules once vLLM is imported.
Updating `current_platform` value directly will not work in those modules.
So it needs the wrapper here to provide a dynamic platform loading
mechanism.
This class can make sure that the `current_platform` is always up-to-date.
"""

def __init__(self):
self.platform: Platform = UnspecifiedPlatform()

def initialize_current_platform(self):
"""Initialize the current platform. This function is called when loading
the vllm plugin."""
# Get the current platform from the registry first. If the current
# platform is not set, try to detect the current platform.
if PlatformRegistry.current_platform is not None:
self.platform = PlatformRegistry.get_current_platform_cls()()
else:
self.platform = detect_current_platform()

# Register custom ops for the current platform.
from vllm.attention.layer import register_custom_ops
register_custom_ops()

def __getattr__(self, name: str) -> Any:
"""Get the attribute. If the attribute is not found, go pass to the
current platform."""
if name == 'platform':
return self.platform
if name == 'initialize_current_platform':
return self.initialize_current_platform
# Go pass to the current platform.
return getattr(self.platform, name)


# The global variable for other modules to use.
current_platform = CurrentPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
Loading

0 comments on commit 62b074a

Please sign in to comment.