Skip to content

Commit

Permalink
[platforms] enable platform plugins (#11602)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Dec 30, 2024
1 parent 5dbf854 commit b12e87f
Show file tree
Hide file tree
Showing 23 changed files with 360 additions and 181 deletions.
25 changes: 19 additions & 6 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,12 @@ steps:
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

Expand Down Expand Up @@ -333,8 +331,6 @@ steps:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py

Expand Down Expand Up @@ -469,11 +465,28 @@ steps:
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py

- label: Plugin Tests (2 GPUs) # 40min
working_dir: "/vllm-workspace/tests"
num_gpus: 2
fast_check: true
source_file_dependencies:
- vllm/plugins/
- tests/plugins/
commands:
# begin platform plugin tests, all the code in-between runs on dummy platform
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
# end platform plugin tests
# other tests continue here:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process

- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
Expand Down
6 changes: 4 additions & 2 deletions docs/source/design/plugin_system.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ Every plugin has three parts:
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.

## What Can Plugins Do?
## Types of supported plugins

Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
- **General plugins** (with group name `vllm.general_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model inside the plugin function.

- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.

## Guidelines for Writing Plugins

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity)
Expand Down Expand Up @@ -242,6 +241,7 @@ def video_assets() -> _VideoAssets:
class HfRunner:

def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
from vllm.platforms import current_platform
if x is None or isinstance(x, (bool, )):
return x

Expand Down
16 changes: 8 additions & 8 deletions tests/kernels/test_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.platforms import cpu, cuda, openvino, rocm
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL


Expand All @@ -20,26 +23,23 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)

if device == "cpu":
with patch("vllm.attention.selector.current_platform",
cpu.CpuPlatform()):
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform",
rocm.RocmPlatform()):
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform",
openvino.OpenVinoPlatform()):
OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
else:
with patch("vllm.attention.selector.current_platform",
cuda.CudaPlatform()):
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name
Expand Down
11 changes: 11 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from setuptools import setup

setup(
name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
'vllm.platform_plugins': [
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa
]
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import Optional


def dummy_platform_plugin() -> Optional[str]:
return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from vllm.platforms.cuda import CudaPlatform


class DummyPlatform(CudaPlatform):
device_name = "DummyDevice"
16 changes: 16 additions & 0 deletions tests/plugins_tests/test_platform_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def test_platform_plugins():
# simulate workload by running an example
import runpy
current_file = __file__
import os
example_file = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(current_file))),
"examples", "offline_inference.py")
runpy.run_path(example_file)

# check if the plugin is loaded correctly
from vllm.platforms import _init_trace, current_platform
assert current_platform.device_name == "DummyDevice", (
f"Expected DummyDevice, got {current_platform.device_name}, "
"possibly because current_platform is imported before the plugin"
f" is loaded. The first import:\n{_init_trace}")
15 changes: 12 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform, interface
from vllm.platforms import CpuArchEnum
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
Expand Down Expand Up @@ -349,6 +349,7 @@ def __init__(self,
self.is_hybrid = self._init_is_hybrid()
self.has_inner_state = self._init_has_inner_state()

from vllm.platforms import current_platform
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:
Expand Down Expand Up @@ -589,6 +590,7 @@ def _verify_quantization(self) -> None:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
from vllm.platforms import current_platform
current_platform.verify_quantization(self.quantization)
if self.quantization not in optimized_quantization_methods:
logger.warning(
Expand Down Expand Up @@ -644,6 +646,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,

# Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid
from vllm.platforms import current_platform
if not current_platform.is_async_output_supported(self.enforce_eager):
logger.warning(
"Async output processing is not supported on the "
Expand Down Expand Up @@ -1012,6 +1015,7 @@ def _verify_args(self) -> None:
raise ValueError(
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
from vllm.platforms import current_platform
if (current_platform.is_cuda() and self.block_size is not None
and self.block_size > 32):
raise ValueError("CUDA Paged Attention kernel only supports "
Expand Down Expand Up @@ -1279,6 +1283,7 @@ def __post_init__(self) -> None:
f"distributed executor backend "
f"'{self.distributed_executor_backend}'.")
ray_only_devices = ["tpu", "hpu"]
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices
and self.world_size > 1):
if self.distributed_executor_backend is None:
Expand Down Expand Up @@ -1327,7 +1332,7 @@ def use_ray(self) -> bool:
def _verify_args(self) -> None:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase

from vllm.platforms import current_platform
if self.distributed_executor_backend not in (
"ray", "mp", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
Expand Down Expand Up @@ -1528,6 +1533,7 @@ def compute_hash(self) -> str:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
from vllm.platforms import current_platform
self.device_type = current_platform.device_type
if not self.device_type:
raise RuntimeError("Failed to infer device type")
Expand Down Expand Up @@ -2241,9 +2247,10 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype

from vllm.platforms import current_platform
if (current_platform.is_cpu()
and current_platform.get_cpu_architecture()
== interface.CpuArchEnum.POWERPC
== CpuArchEnum.POWERPC
and (config_dtype == torch.float16
or config_dtype == torch.float32)):
logger.info(
Expand Down Expand Up @@ -3083,6 +3090,7 @@ def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
from vllm.platforms import current_platform
if model_config.quantization is not None:
from vllm.model_executor.model_loader.weight_utils import (
get_quant_config)
Expand Down Expand Up @@ -3145,6 +3153,7 @@ def __post_init__(self):
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config)

from vllm.platforms import current_platform
if self.scheduler_config is not None and \
self.model_config is not None and \
self.scheduler_config.chunked_prefill_enabled and \
Expand Down
3 changes: 2 additions & 1 deletion vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op

if TYPE_CHECKING:
Expand Down Expand Up @@ -194,6 +193,7 @@ def __init__(
assert self.cpu_group is not None
assert self.device_group is not None

from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
Expand Down Expand Up @@ -1188,6 +1188,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
from vllm.platforms import current_platform
if not current_platform.is_cpu():
torch.cuda.empty_cache()

Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean
Expand Down Expand Up @@ -1094,6 +1093,7 @@ def create_engine_config(self,
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
from vllm.platforms import current_platform
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
Expand Down
2 changes: 1 addition & 1 deletion vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
Expand Down Expand Up @@ -229,6 +228,7 @@ def initialize_ray_cluster(
the default Ray cluster address.
"""
assert_ray_available()
from vllm.platforms import current_platform

# Connect to a ray cluster.
if current_platform.is_rocm() or current_platform.is_xpu():
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms import CpuArchEnum

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
Expand Down Expand Up @@ -39,6 +39,7 @@ def maybe_backend_fallback(

if guided_params.backend == "xgrammar":
# xgrammar only has x86 wheels for linux, fallback to outlines
from vllm.platforms import current_platform
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
logger.warning("xgrammar is only supported on x86 CPUs. "
"Falling back to use outlines instead.")
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
Expand Down Expand Up @@ -273,6 +272,7 @@ def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
from vllm.platforms import current_platform
current_platform.verify_model_arch(model_arch)
try:
return model.load_model_cls()
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import torch

from vllm.platforms import current_platform


def set_random_seed(seed: int) -> None:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)


Expand Down Expand Up @@ -38,6 +37,7 @@ def set_weight_attrs(
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
from vllm.platforms import current_platform
if current_platform.is_tpu() and key == "weight_loader":
value = _make_synced_weight_loader(value)
setattr(weight, key, value)
Expand Down
Loading

0 comments on commit b12e87f

Please sign in to comment.