[Platform] platform agnostic for EngineArgs initialization (vllm-project#11225)

Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan authored Dec 17, 2024
1 parent 59c9b6e commit e88db68
Showing 9 changed files with 37 additions and 6 deletions.
8 changes: 2 additions & 6 deletions vllm/engine/arg_utils.py
@@ -112,9 +112,7 @@ class EngineArgs:
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     max_parallel_loading_workers: Optional[int] = None
-    # NOTE(kzawora): default block size for Gaudi should be 128
-    # smaller sizes still work, but very inefficiently
-    block_size: int = 16 if not current_platform.is_hpu() else 128
+    block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
     disable_sliding_window: bool = False
     use_v2_block_manager: bool = True
@@ -1036,9 +1034,7 @@ def create_engine_config(self,
             self.enable_prefix_caching = False

         cache_config = CacheConfig(
-            # neuron needs block_size = max_model_len
-            block_size=self.block_size if self.device != "neuron" else
-            (self.max_model_len if self.max_model_len is not None else 0),
+            block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
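The net effect of this hunk, together with the per-platform hooks below, is that EngineArgs no longer picks a platform-dependent block size at import time; it leaves block_size unset (None) and each platform's check_and_update_config fills in its own default. A minimal, self-contained sketch of that deferred-default pattern (the class names here are simplified stand-ins, not vLLM's actual types):

from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheConfig:
    # No platform-specific default at declaration time; None means "not chosen yet".
    block_size: Optional[int] = None


class GaudiLikePlatform:
    @classmethod
    def check_and_update_config(cls, cache_config: CacheConfig) -> None:
        # The platform supplies its default only if the user left it unset.
        if cache_config.block_size is None:
            cache_config.block_size = 128


cfg = CacheConfig()
GaudiLikePlatform.check_and_update_config(cfg)
assert cfg.block_size == 128  # default chosen by the platform, not by EngineArgs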
3 changes: 3 additions & 0 deletions vllm/platforms/cpu.py
@@ -60,6 +60,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         cache_config = vllm_config.cache_config

+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

         if kv_cache_space >= 0:
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
@@ -137,6 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"

+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+

 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
6 changes: 6 additions & 0 deletions vllm/platforms/hpu.py
@@ -48,6 +48,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

+        # NOTE(kzawora): default block size for Gaudi should be 128
+        # smaller sizes still work, but very inefficiently
+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 128
+
     @classmethod
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on HPU.")
6 changes: 6 additions & 0 deletions vllm/platforms/neuron.py
@@ -33,6 +33,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = \
                 "vllm.worker.neuron_worker.NeuronWorker"

+        cache_config = vllm_config.cache_config
+        if cache_config:
+            # neuron needs block_size = max_model_len
+            vllm_config.cache_config.block_size = \
+                vllm_config.model_config.max_model_len
+
     @classmethod
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")
3 changes: 3 additions & 0 deletions vllm/platforms/openvino.py
@@ -87,6 +87,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # check and update cache config
         ov_core = ov.Core()
         cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
             if not OpenVinoPlatform.is_openvino_cpu():
                 logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
4 changes: 4 additions & 0 deletions vllm/platforms/rocm.py
@@ -84,6 +84,10 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:

     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
5 changes: 5 additions & 0 deletions vllm/platforms/tpu.py
@@ -46,6 +46,11 @@ def inference_mode(cls):
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel
+
+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         compilation_config = vllm_config.compilation_config
         if compilation_config.level == CompilationLevel.NO_COMPILATION:
             # TPU does not support NO_COMPILATION
4 changes: 4 additions & 0 deletions vllm/platforms/xpu.py
@@ -51,6 +51,10 @@ def inference_mode():

     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         # check and update model config
         model_config = vllm_config.model_config
         if model_config.dtype == torch.bfloat16:
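All of the platform files above follow the same shape, so a new backend would plug in the same way. A sketch of a hypothetical out-of-tree platform, under the assumption that it subclasses the Platform base class from vllm/platforms/interface.py and imports VllmConfig from vllm.config; the MyAccelPlatform name and the 32-token default are made up for illustration:

from vllm.config import VllmConfig
from vllm.platforms.interface import Platform


class MyAccelPlatform(Platform):

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        cache_config = vllm_config.cache_config
        # Only supply a default when the user did not pass --block-size.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 32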
