[misc] use nvml to get consistent device name (vllm-project#7582)
youkaichao authored Aug 17, 2024
1 parent 7c0b7ea commit eed020f
Showing 5 changed files with 46 additions and 2 deletions.
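
Background for the change (not stated in the commit body, so hedged): torch.cuda.get_device_name() follows CUDA's device enumeration, which can disagree with NVML's ordering on hosts with mixed GPUs unless CUDA_DEVICE_ORDER=PCI_BUS_ID is set, so the diff routes device-name lookups through the platform layer and answers them from NVML on CUDA systems. A minimal sketch of the NVML lookup, assuming pynvml (nvidia-ml-py) is installed and at least one NVIDIA GPU is visible:

import pynvml

pynvml.nvmlInit()
try:
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    name = pynvml.nvmlDeviceGetName(handle)
    # Older pynvml releases return bytes here; newer nvidia-ml-py returns str.
    if isinstance(name, bytes):
        name = name.decode()
    print(name)  # e.g. "NVIDIA A100-SXM4-80GB" (illustrative)
finally:
    pynvml.nvmlShutdown()
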
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -11,6 +11,7 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -287,7 +288,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,


 def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
-    device_name = torch.cuda.get_device_name().replace(" ", "_")
+    device_name = current_platform.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
     return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"

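Illustrative effect of the change above, with invented E and N values and an assumed device name:

# Hypothetical call; E, N and the device name are illustrative.
# With current_platform.get_device_name() == "NVIDIA A100-SXM4-80GB":
get_config_file_name(E=8, N=14336, dtype=None)
# -> "E=8,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json"
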
34 changes: 34 additions & 0 deletions vllm/platforms/cuda.py
@@ -44,6 +44,35 @@ def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return pynvml.nvmlDeviceGetCudaComputeCapability(handle)


+@lru_cache(maxsize=8)
+@with_nvml_context
+def get_physical_device_name(device_id: int = 0) -> str:
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    return pynvml.nvmlDeviceGetName(handle)
+
+
+@with_nvml_context
+def warn_if_different_devices():
+    device_ids: int = pynvml.nvmlDeviceGetCount()
+    if device_ids > 1:
+        device_names = [get_physical_device_name(i) for i in range(device_ids)]
+        if len(set(device_names)) > 1 and os.environ.get(
+                "CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
+            logger.warning(
+                "Detected different devices in the system: \n%s\nPlease"
+                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
+                "avoid unexpected behavior.", "\n".join(device_names))
+
+
+try:
+    from sphinx.ext.autodoc.mock import _MockModule
+
+    if not isinstance(pynvml, _MockModule):
+        warn_if_different_devices()
+except ModuleNotFoundError:
+    warn_if_different_devices()
+
+
 def device_id_to_physical_device_id(device_id: int) -> int:
     if "CUDA_VISIBLE_DEVICES" in os.environ:
         device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
@@ -61,6 +90,11 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_capability(physical_device_id)

+    @staticmethod
+    def get_device_name(device_id: int = 0) -> str:
+        physical_device_id = device_id_to_physical_device_id(device_id)
+        return get_physical_device_name(physical_device_id)
+
     @staticmethod
     @with_nvml_context
     def is_full_nvlink(physical_device_ids: List[int]) -> bool:
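A usage sketch of the new CUDA path (not part of the diff), assuming pynvml is installed, the host has at least two GPUs, and CUDA_VISIBLE_DEVICES remaps them; logical_to_physical is a hypothetical helper mirroring device_id_to_physical_device_id above:

import os
import pynvml

os.environ["CUDA_VISIBLE_DEVICES"] = "1,0"  # hypothetical remapping

def logical_to_physical(device_id: int) -> int:
    # Same idea as device_id_to_physical_device_id: index into the
    # comma-separated list to recover the physical device index.
    return int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[device_id])

pynvml.nvmlInit()
try:
    # NVML always enumerates physical devices, so logical device 0 is
    # resolved to physical index 1 before asking NVML for its name.
    handle = pynvml.nvmlDeviceGetHandleByIndex(logical_to_physical(0))
    print(pynvml.nvmlDeviceGetName(handle))
finally:
    pynvml.nvmlShutdown()

One property of NVML worth noting here: it can be queried without initializing a CUDA context in the calling process.
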
4 changes: 4 additions & 0 deletions vllm/platforms/interface.py
@@ -27,6 +27,10 @@ def is_tpu(self) -> bool:
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         raise NotImplementedError

+    @staticmethod
+    def get_device_name(device_id: int = 0) -> str:
+        raise NotImplementedError
+
     @staticmethod
     def inference_mode():
         """A device-specific wrapper of `torch.inference_mode`.
5 changes: 5 additions & 0 deletions vllm/platforms/rocm.py
@@ -13,3 +13,8 @@ class RocmPlatform(Platform):
     @lru_cache(maxsize=8)
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         return torch.cuda.get_device_capability(device_id)
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(device_id: int = 0) -> str:
+        return torch.cuda.get_device_name(device_id)
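
Call sites use the current_platform singleton rather than these classes directly. The selection logic lives in vllm/platforms/__init__.py and is not part of this diff; the sketch below only illustrates the dispatch idea, and the branch condition and imports are assumptions:

# Illustrative only -- not the actual selection code.
import torch
from vllm.platforms.cuda import CudaPlatform  # assumed export
from vllm.platforms.rocm import RocmPlatform  # assumed export

current_platform = RocmPlatform() if torch.version.hip else CudaPlatform()
print(current_platform.get_device_name(0))  # NVML name on CUDA, torch name on ROCm
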
2 changes: 1 addition & 1 deletion vllm/worker/worker.py
@@ -368,7 +368,7 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     if torch_dtype == torch.bfloat16:
         compute_capability = current_platform.get_device_capability()
         if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
+            gpu_name = current_platform.get_device_name()
             raise ValueError(
                 "Bfloat16 is only supported on GPUs with compute capability "
                 f"of at least 8.0. Your {gpu_name} GPU has compute capability "
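For reference, a hedged reproduction of the guard above with invented values; a pre-Ampere GPU such as a T4 reports major compute capability 7, so bfloat16 is rejected:

# Invented values for illustration only.
compute_capability = (7, 5)   # e.g. an NVIDIA T4
gpu_name = "Tesla T4"
if compute_capability[0] < 8:
    print(f"{gpu_name} (compute capability "
          f"{compute_capability[0]}.{compute_capability[1]}) "
          "does not support bfloat16 natively.")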
