Skip to content

Commit

Permalink
Add platform pluggable
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Dec 16, 2024
1 parent 69ba344 commit 3ec575e
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 142 deletions.
9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from setuptools import setup

setup(name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
'vllm.general_plugins':
["register_dummy_model = vllm_add_dummy_platform:register"]
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from vllm import PlatformRegistry


def register():
# Register the dummy platform
PlatformRegistry.register(
"my_platform", "vllm_add_dummy_platform.my_platform:DummyPlatform")
# Set the current platform to the dummy platform
PlatformRegistry.set_current_platform("my_platform")
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from vllm.platforms import Platform


class DummyPlatform(Platform):

def __init__(self):
super().__init__()

def get_device_name(self) -> str:
return "dummy"
2 changes: 2 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.platforms.registry import PlatformRegistry
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

Expand All @@ -22,6 +23,7 @@
"__version_tuple__",
"LLM",
"ModelRegistry",
"PlatformRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
Expand Down
127 changes: 12 additions & 115 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,20 @@
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
from .interface import CpuArchEnum, Platform, PlatformEnum
from .registry import PlatformRegistry, detect_current_platform

current_platform: Platform

# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

is_tpu = False
try:
# While it's technically possible to install libtpu on a non-TPU machine,
# this is a very uncommon scenario. Therefore, we assume that libtpu is
# installed if and only if the machine has TPUs.
import libtpu # noqa: F401
is_tpu = True
except Exception:
pass
def initialize_current_platform():
"""Initialize the current platform. This function is called when loading
the vllm plugin."""
global current_platform
# Get the current platform from the registry first. If the current platform
# is not set, try to detect the current platform.
if PlatformRegistry.current_platform is not None:
current_platform = PlatformRegistry.get_current_platform()
else:
current_platform = detect_current_platform()

is_cuda = False

try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
# CUDA is supported on Jetson, but NVML may not be.
import os

def cuda_is_jetson() -> bool:
return os.path.isfile("/etc/nv_tegra_release") \
or os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
is_cuda = True

is_rocm = False

try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass

is_hpu = False
try:
from importlib import util
is_hpu = util.find_spec('habana_frameworks') is not None
except Exception:
pass

is_xpu = False

try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
except Exception:
pass

is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
except Exception:
pass

is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass

is_openvino = False
try:
from importlib.metadata import version
is_openvino = "openvino" in version("vllm")
except Exception:
pass

if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
else:
current_platform = UnspecifiedPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
170 changes: 170 additions & 0 deletions vllm/platforms/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from dataclasses import dataclass, field
from typing import Dict, Optional

from vllm.platforms import Platform

from .interface import UnspecifiedPlatform

_VLLM_PLATFORMS = {
"cpu": "vllm.platforms.cpu:CpuPlatform",
"cuda": "vllm.platforms.cuda:CudaPlatform",
"hpu": "vllm.platforms.hpu:HpuPlatform",
"neuron": "vllm.platforms.neuron:NeuronPlatform",
"openvino": "vllm.platforms.openvino:OpenVinoPlatform",
"rocm": "vllm.platforms.rocm:RocmPlatform",
"tpu": "vllm.platforms.tpu:TpuPlatform",
"xpu": "vllm.platforms.xpu:XPUPlatform",
}


@dataclass
class _PlatformRegistry:
platforms: Dict[str, str] = field(default_factory=dict)
current_platform: Optional[str] = None

def register(self, device_name: str, platform: str):
"""Register a platform by device name. This function is called by the
platform plugin."""
if device_name in self.platforms:
raise ValueError(f"Platform {device_name} already registered.")
self.platforms[device_name] = platform

def load_platform_cls(self, device_name: str) -> Platform:
"""Load a platform object by device name."""
if device_name not in self.platforms:
raise ValueError(
f"Platform {device_name} not registered. "
f"Available platforms: {list(self.platforms.keys())}")
platform_cls_str = self.platforms[device_name]
module_name, cls_name = platform_cls_str.split(":")
module = __import__(module_name, fromlist=[cls_name])
return getattr(module, cls_name)

def set_current_platform(self, device_name: str):
"""Set the current platform by device name."""
if device_name not in self.platforms:
raise ValueError(
f"Platform {device_name} not registered. "
f"Available platforms: {list(self.platforms.keys())}")
self.current_platform = device_name

def get_current_platform(self) -> Platform:
"""Get the current platform object."""
if self.current_platform is None:
raise ValueError("No current platform set.")
return self.load_platform_cls(self.current_platform)


PlatformRegistry = _PlatformRegistry({
device_name: platform
for device_name, platform in _VLLM_PLATFORMS.items()
})


def detect_current_platform() -> Platform:
"""Detect the current platform by checking the installed packages."""
CurrentPlatform: Optional[type[Platform]] = None
# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

# Load TPU Platform
try:
# While it's technically possible to install libtpu on a non-TPU
# machine, this is a very uncommon scenario. Therefore, we assume that
# libtpu is installed if and only if the machine has TPUs.
import libtpu # noqa: F401

from .tpu import TpuPlatform as CurrentPlatform
except Exception:
pass

# Load CUDA Platform
if not CurrentPlatform:
try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
from .cuda import CudaPlatform as CurrentPlatform
finally:
pynvml.nvmlShutdown()
except Exception:
# CUDA is supported on Jetson, but NVML may not be.
import os

def cuda_is_jetson() -> bool:
return os.path.isfile("/etc/nv_tegra_release") \
or os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
from .cuda import CudaPlatform as CurrentPlatform

# Load ROCm Platform
if not CurrentPlatform:
try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
from .rocm import RocmPlatform as CurrentPlatform
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass

# Load HPU Platform
if not CurrentPlatform:
try:
from importlib import util
assert util.find_spec('habana_frameworks') is not None
from .hpu import HpuPlatform as CurrentPlatform
except Exception:
pass

# Load XPU Platform
if not CurrentPlatform:
try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
from .xpu import XPUPlatform as CurrentPlatform
except Exception:
pass

# Load CPU Platform
if not CurrentPlatform:
try:
from importlib.metadata import version
assert "cpu" in version("vllm")
from .cpu import CpuPlatform as CurrentPlatform
except Exception:
pass

# Load Neuron Platform
if not CurrentPlatform:
try:
import transformers_neuronx # noqa: F401

from .neuron import NeuronPlatform as CurrentPlatform
except ImportError:
pass

# Load OpenVINO Platform
if not CurrentPlatform:
try:
from importlib.metadata import version
assert "openvino" in version("vllm")
from .openvino import OpenVinoPlatform as CurrentPlatform
except Exception:
pass

if CurrentPlatform:
device_name = CurrentPlatform.get_device_name()
PlatformRegistry.set_current_platform(device_name)
current_platform = CurrentPlatform()
else:
current_platform = UnspecifiedPlatform()
return current_platform
Loading

0 comments on commit 3ec575e

Please sign in to comment.