[Bug]: vllm failed to run two instance with one gpu #10533

Closed
1 task done
pandada8 opened this issue Nov 21, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@pandada8

Your current environment

The output of `python collect_env.py`
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.12.7 (main, Oct  1 2024, 08:52:12) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-73-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.113.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          16
On-line CPU(s) list:             0-15
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz
CPU family:                      6
Model:                           106
Thread(s) per core:              2
Core(s) per socket:              8
Socket(s):                       1
Stepping:                        6
BogoMIPS:                        5799.99
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm arch_capabilities
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       384 KiB (8 instances)
L1i cache:                       256 KiB (8 instances)
L2 cache:                        10 MiB (8 instances)
L3 cache:                        48 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-15
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] flashinfer==0.1.6+cu121torch2.4
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.46.2
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.4.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      0-15            N/A             N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NVIDIA_VISIBLE_DEVICES=all
NVIDIA_REQUIRE_CUDA=cuda>=12.4 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 brand=tesla,driver>=535,driver<536 brand=unknown,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=geforce,driver>=535,driver<536 brand=geforcertx,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=titan,driver>=535,driver<536 brand=titanrtx,driver>=535,driver<536
NVIDIA_DRIVER_CAPABILITIES=compute,utility
VLLM_USAGE_SOURCE=production-docker-image
CUDA_VERSION=12.4.1
LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/cv2/../../lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
CUDA_MODULE_LOADING=LAZY

Model Input Dumps

No response

🐛 Describe the bug

I want to run Qwen2.5-14B-Instruct-GPTQ-Int4 and Qwen2.5-72B-Instruct-GPTQ-Int4 side by side on a single 80 GB A100,

using the following commands:

python3 -m vllm.entrypoints.openai.api_server \
        --model /models/Qwen2.5-72B-Instruct-GPTQ-Int4 \
        --gpu-memory-utilization 0.75 \
        --max-model-len=16384 \
        --served-model-name qwen2.5-14b \
        --quantization=gptq \
        --enforce-eager \
        --enable-chunked-prefill \
        --enable-prefix-caching
python3 -m vllm.entrypoints.openai.api_server \
        --model /models/Qwen2.5-14B-Instruct-GPTQ-Int4 \
        --gpu-memory-utilization 0.20 \
        --max-model-len=16384 \
        --served-model-name qwen2.5-14b \
        --quantization=gptq \
        --enforce-eager \
        --enable-chunked-prefill \
        --enable-prefix-caching

Each command succeeds when run on its own, consuming approximately 75% and 20% of the VRAM respectively.
However, if the 72B server is started first and the 14B server second, the 14B server refuses to start and raises the following exception:

ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.

Note that the 14B vLLM server logs the following:

Memory profiling results: total_gpu_memory=79.15GiB initial_memory_usage=68.01GiB peak_torch_memory=10.77GiB memory_usage_post_profile=68.04Gib non_torch_memory=58.66GiB kv_cache_size=-53.60GiB gpu_memory_utilization=0.20

The reported kv_cache_size is negative.

It looks like the cause is in the current code:

vllm/vllm/worker/worker.py

Lines 213 to 219 in 1cfde82

non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
    peak_memory += non_torch_allocations
available_kv_cache_memory = (
    total_gpu_memory * self.cache_config.gpu_memory_utilization -
    peak_memory)

Here peak_memory includes the VRAM already used by the other vLLM instance, which makes available_kv_cache_memory negative.
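
The negative value can be reproduced by hand from the numbers in the log line. The following is a minimal sketch, assuming that the logged peak_torch_memory and non_torch_memory are the two terms summed into peak_memory, as the snippet from worker.py suggests. The variable names mirror the log fields and are not vLLM API.

```python
# Values copied from the "Memory profiling results" log line (all in GiB).
total_gpu_memory = 79.15
gpu_memory_utilization = 0.20
peak_torch_memory = 10.77
non_torch_memory = 58.66

# Assumption: positive non-torch allocations are added to the profiled peak,
# as in the worker.py snippet quoted in this issue.
peak_memory = peak_torch_memory + non_torch_memory

# The budget is a fraction of the whole card, minus everything seen in use,
# including memory held by the other vLLM instance.
kv_cache_size = total_gpu_memory * gpu_memory_utilization - peak_memory
print(f"kv_cache_size={kv_cache_size:.2f}GiB")  # kv_cache_size=-53.60GiB
```

The result matches the kv_cache_size reported in the log, which supports the reading that the other instance's memory is being counted against this instance's 20% budget.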

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@pandada8 pandada8 added the bug Something isn't working label Nov 21, 2024
@pandada8 pandada8 changed the title [Bug]: [Bug]: vllm failed to run two instance with one gpu Nov 21, 2024
@pandada8
Author

pandada8 commented Nov 21, 2024

#2248 (comment) suggests counting the VRAM already in use before the model is loaded and excluding it when calculating available_kv_cache_memory. This seems to work well when the startup of different vLLM workers on the same card is staggered.
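
As an illustration of the "count what is already used" idea, here is a small sketch that takes the memory probe as a parameter so it can run without a GPU. The helper name pre_occupied_bytes and the fake probe are hypothetical; on a real device the probe would be torch.cuda.mem_get_info, which returns a (free, total) pair in bytes.

```python
from typing import Callable, Tuple

GIB = 1024 ** 3


def pre_occupied_bytes(mem_get_info: Callable[[], Tuple[int, int]]) -> int:
    """Bytes already in use on the device before this process loads a model.

    `mem_get_info` must return (free_bytes, total_bytes), the same shape
    that torch.cuda.mem_get_info() returns.
    """
    free_bytes, total_bytes = mem_get_info()
    return total_bytes - free_bytes


def fake_probe() -> Tuple[int, int]:
    # Hypothetical card: 80 GiB total, 20 GiB free because another
    # process already holds 60 GiB.
    return 20 * GIB, 80 * GIB


baseline = pre_occupied_bytes(fake_probe)
print(baseline // GIB)  # 60
```

The baseline is only meaningful if it is taken before the second instance allocates anything, which is why staggered startup matters: two workers probing at the same moment would each miss the other's later allocations.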

@pandada8
Author

Proposed change:

diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 80fd7bc3..793812e5 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -150,6 +150,8 @@ class Worker(LocalOrDistributedWorkerBase):
         set_random_seed(self.model_config.seed)
 
     def load_model(self):
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        self.gpu_mem_pre_occupied = total_gpu_memory - free_gpu_memory
         self.model_runner.load_model()
 
     def save_sharded_state(
@@ -214,9 +216,7 @@ class Worker(LocalOrDistributedWorkerBase):
         if non_torch_allocations > 0:
             peak_memory += non_torch_allocations
 
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            peak_memory)
+        available_kv_cache_memory = total_gpu_memory * self.cache_config.gpu_memory_utilization - (peak_memory - self.gpu_mem_pre_occupied)
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
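
To see the effect of the proposed formula, the sketch below compares the current and proposed calculations as a pure function. The function name is hypothetical, the totals come from the log line in this issue, and the pre-occupied figure of 58.0 GiB is an assumed value chosen for illustration, not a measurement.

```python
def available_kv_cache_gib(total: float, utilization: float,
                           peak: float, pre_occupied: float = 0.0) -> float:
    """KV-cache budget in GiB.

    With pre_occupied=0.0 this is the current formula; a non-zero value
    applies the proposed correction by removing memory that was already
    in use before this instance loaded its model.
    """
    own_peak = peak - pre_occupied
    return total * utilization - own_peak


total = 79.15          # total_gpu_memory from the log
utilization = 0.20     # gpu_memory_utilization from the log
peak = 10.77 + 58.66   # peak_torch_memory + non_torch_memory from the log
pre_occupied = 58.0    # HYPOTHETICAL: usage by the other instance at startup

current = available_kv_cache_gib(total, utilization, peak)
proposed = available_kv_cache_gib(total, utilization, peak, pre_occupied)
print(f"current={current:.2f}GiB proposed={proposed:.2f}GiB")
```

Under these assumptions the current formula yields a negative budget while the proposed one leaves a few GiB for cache blocks. The correction assumes the other process's usage stays constant after the baseline is taken.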

@DarkLight1337
Member

Please see #10511
