[Bug]: The random seed behavior when loading a model in vLLM is confusing. #11953

Open
1 task done
Aratako opened this issue Jan 11, 2025 · 2 comments
Labels
bug Something isn't working

Comments

Aratako commented Jan 11, 2025

Your current environment

The output of `python collect_env.py`
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.11.10 (main, Sep  7 2024, 18:35:41) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.8.0-50-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 Ti
Nvidia driver version: 550.142
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               12
On-line CPU(s) list:                  0-11
Vendor ID:                            AuthenticAMD
Model name:                           AMD Ryzen 5 5500
CPU family:                           25
Model:                                80
Thread(s) per core:                   2
Core(s) per socket:                   6
Socket(s):                            1
Stepping:                             0
CPU max MHz:                          4267.0000
CPU min MHz:                          400.0000
BogoMIPS:                             7186.55
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
L1d cache:                            192 KiB (6 instances)
L1i cache:                            192 KiB (6 instances)
L2 cache:                             3 MiB (6 instances)
L3 cache:                             16 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-11
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==24.0.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.4.1+cu124
[pip3] torchvision==0.20.1
[pip3] transformers==4.48.0
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.6.post2.dev178+g7a3a83e3
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      0-11    0               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NVIDIA_VISIBLE_DEVICES=GPU-5fcb4f33-7dff-760b-bf25-7a010fdd0865
NVIDIA_REQUIRE_CUDA=cuda>=12.4 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 brand=tesla,driver>=535,driver<536 brand=unknown,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=geforce,driver>=535,driver<536 brand=geforcertx,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=titan,driver>=535,driver<536 brand=titanrtx,driver>=535,driver<536
NCCL_VERSION=2.21.5-1
NVIDIA_DRIVER_CAPABILITIES=compute,display,graphics,utility,video
NVIDIA_PRODUCT_NAME=CUDA
CUDA_VERSION=12.4.1
LD_LIBRARY_PATH=/usr/local/lib/python3.11/dist-packages/cv2/../../lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
CUDA_MODULE_LOADING=LAZY

Model Input Dumps

No response

🐛 Describe the bug

Description

When loading a model in vLLM, the seed parameter unintentionally affects the global random states (random, np.random), which can lead to surprising behavior if the user is not explicitly aware of it.

Specifically:

  1. If the user does not specify the seed parameter, its default value (0) is used, and the global random states are seeded with that value during model initialization.
  2. This means the behavior of random operations in user code outside vLLM becomes unintentionally fixed, which can cause subtle bugs.

For example, if the user assumes that random values (e.g., generated with random.choice or np.random.rand) will vary across runs, they might encounter identical results across multiple script executions.

Steps to Reproduce

Here is a minimal example illustrating the issue:

import random
from vllm import LLM

# Initialize a vLLM model
model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")

# Try generating random numbers
print(random.randint(0, 100))  # Outputs the same number every time the script is run
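
One possible user-side workaround is to snapshot the global random state before the model is constructed and restore it afterwards. The sketch below uses only the standard library; `preserve_global_rng` and `fake_model_init` are hypothetical names introduced here for illustration, and `fake_model_init` merely stands in for a call such as `LLM(...)` that seeds the global generator as a side effect. If NumPy or PyTorch state also matters, the same pattern should apply with `np.random.get_state()`/`np.random.set_state()` and `torch.get_rng_state()`/`torch.set_rng_state()`.

```python
import random
from contextlib import contextmanager


@contextmanager
def preserve_global_rng():
    """Save the global `random` state on entry and restore it on exit."""
    saved_state = random.getstate()
    try:
        yield
    finally:
        random.setstate(saved_state)


def fake_model_init(seed=0):
    # Hypothetical stand-in for a library call that seeds the global RNG
    # as a side effect (the behavior this issue reports for model loading).
    random.seed(seed)


state_before = random.getstate()
with preserve_global_rng():
    fake_model_init()
state_after = random.getstate()

# The global generator continues from where it was before the call.
print(state_before == state_after)
```

Note that this restores the state for user code only; whatever runs inside the `with` block still sees the seeded generator.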

Additionally, a practical scenario where users might encounter this issue is during synthetic data generation, where random seed texts are used to create prompts for a model:

import random
from vllm import LLM, SamplingParams

# 1) Load the model
model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")

# 2) Load a dataset of seed texts
seed_dataset = ["text1", "text2", "text3"]

# 3) Randomly select a seed text to create prompts
def prepare_prompts(batch_size, dataset):
    prompts = []
    for _ in range(batch_size):
        seed_text = random.choice(dataset)
        prompts.append(f"Generate a QA pair based on: {seed_text}")
    return prompts

# 4) Generate outputs
prompts = prepare_prompts(3, seed_dataset)
outputs = model.generate(prompts, sampling_params=SamplingParams(temperature=0.7, max_tokens=100))

print(outputs)

In this scenario, random.choice is expected to select different seed texts across multiple runs of the script. However, due to the default seed=0, the global random state is fixed, causing the same seed text to be chosen every time. As a result, the model outputs identical results across multiple executions, which is counterintuitive for users expecting randomness.
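
Until the default changes, one way to keep prompt selection independent of whatever a library does to the global state is to sample from a dedicated `random.Random` instance. This is a minimal sketch, not part of vLLM; `prompt_rng` is a name introduced here, and the `random.seed(0)` call only simulates the side effect described above.

```python
import random

# A private generator, seeded from OS entropy when available.
# Calls to random.seed(...) elsewhere do not affect it.
prompt_rng = random.Random()


def prepare_prompts(batch_size, dataset, rng=prompt_rng):
    """Build prompts by sampling seed texts with a dedicated generator."""
    return [
        f"Generate a QA pair based on: {rng.choice(dataset)}"
        for _ in range(batch_size)
    ]


seed_dataset = ["text1", "text2", "text3"]

random.seed(0)  # simulate a library seeding the global generator
prompts = prepare_prompts(3, seed_dataset)
print(prompts)
```

Because `prompt_rng` carries its own state, the selected seed texts can vary across runs even when the global generator has been fixed.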

Related Issues

  • Issue #8519: This issue discusses the same behavior, suggesting setting seed=None as a solution. However, this workaround leads to an error in the current version of vLLM (torch.manual_seed(seed) fails when seed=None).

Expected Behavior

  • If the seed parameter is not specified, the behavior of global random states should remain unaffected.
  • Alternatively, the documentation should clearly explain that the seed parameter modifies global random states and how users can control this behavior.

Actual Behavior

  • If the user does not explicitly specify a seed, the global random states for random and np.random are unintentionally set to the default value of seed=0 during model initialization.
  • This causes random operations outside vLLM to behave in a non-intuitive, fixed manner.

Proposed Solution

  1. Update the default value of the seed parameter to None.
  2. If seed=None, skip setting any random seed for random, np.random, and torch.manual_seed.
  3. Add clear documentation about how the seed parameter behaves and its effect on global random states.
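
Steps 1 and 2 could look roughly like the following. This is only a sketch of the proposed behavior, not vLLM's actual code; `maybe_seed_everything` is a hypothetical helper name, and only the standard-library generator is seeded so the example stays self-contained.

```python
import random
from typing import Optional


def maybe_seed_everything(seed: Optional[int]) -> None:
    """Seed global RNGs only when a seed was explicitly provided."""
    if seed is None:
        # Leave every global random state untouched.
        return
    random.seed(seed)
    # Assumption: a full version would also call np.random.seed(seed)
    # and torch.manual_seed(seed) at this point.


state = random.getstate()
maybe_seed_everything(None)
unchanged = random.getstate() == state  # no seed given, state preserved

maybe_seed_everything(0)
first = random.random()
maybe_seed_everything(0)
second = random.random()  # same explicit seed, same value as `first`
```

With this shape, reproducibility becomes opt-in: users who pass a seed get deterministic behavior, and everyone else keeps their own random state.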

Why This Matters

This behavior can be highly confusing for users, especially when working on tasks like synthetic data generation or other workflows involving randomness. Unexpectedly fixed random states can lead to subtle bugs and wasted debugging time.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

Aratako added the bug label on Jan 11, 2025

SirlyDreamer commented Jan 20, 2025

Looks like I'm experiencing this exact problem as well.

I ran identical vLLM data generation scripts across 8 H100 GPUs. Despite setting the temperature to 0.8 and not specifying any random seed, I got the exact same output from each GPU. This seems to defeat the purpose of using temperature for diversity in generation.

Environment:

Python Version: 3.11.10
CUDA Version: 12.4.1
PyTorch Version: 2.5.1
vLLM Version: 0.6.6.post1

The code is below:

Generate Script:

import argparse
import json
import os
from vllm import LLM, SamplingParams
from datasets import load_dataset,Dataset
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str)
parser.add_argument('--tensor-parallel-size', type=int, default=1)
parser.add_argument('--dataset', type=str)
parser.add_argument('--save-path', type=str)
args = parser.parse_args()

model_base_name = os.path.basename(args.model)

data = load_dataset(args.dataset)['train']

sampling_params = SamplingParams(
    max_tokens=4096,
    temperature=0.8,
)

llm = LLM(model=args.model, tensor_parallel_size=args.tensor_parallel_size)

prompts = [[{"role": "user","content": d['instruction'],}] for d in data]
outputs = llm.chat(prompts, sampling_params)
results = []

for i, output in tqdm(enumerate(outputs), total=len(outputs)):
    results.append({
        "uuid":data['uuid'][i],
        "instruction":data['instruction'][i],
        "response":output.outputs[0].text,
    })

os.makedirs(args.save_path, exist_ok=True)

# save_to_disk returns None, so the result is not assigned
Dataset.from_list(results).save_to_disk(args.save_path)

Check Code:

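from datasets import load_from_disk
from tqdm import tqdm

# Map each uuid to the response produced by the first run (saved under "0")
uuid_response = {}

ds = load_from_disk("0", keep_in_memory=True)
for entry in tqdm(ds):
    uuid = entry.get("uuid")
    response = entry.get("response")
    uuid_response[uuid] = response

# Responses from the second run (saved under "1") are expected to differ
ds = load_from_disk("1", keep_in_memory=True)
for entry in tqdm(ds):
    uuid = entry.get("uuid")
    response = entry.get("response")
    assert uuid_response[uuid] != response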

Running the check raised an AssertionError.

Wangmerlyn commented

save_to_disk

I think this is reasonable. Since you didn't specify a random seed, it defaults to 0. Even though you ran the same code on different GPUs, they all use the same random seed, so the randomized behavior is the same.
If you want more diverse data (different results on different GPUs), you can set a different random seed for each one.
Hope this helps.😊
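
A minimal sketch of that suggestion, assuming each process knows its own index: `derive_seed` and `worker_index` are hypothetical names introduced here, and the commented-out line only illustrates where the value might be passed, relying on the `seed` parameter of `LLM` discussed in this issue.

```python
import secrets
from typing import Optional


def derive_seed(worker_index: int, base_seed: Optional[int] = None) -> int:
    """Return a per-worker seed in the 32-bit range.

    With a base seed, results are reproducible but differ per worker.
    Without one, a fresh random seed is drawn on every call.
    """
    if base_seed is None:
        return secrets.randbits(32)
    return (base_seed + worker_index) % (2**32)


# Eight workers, one per GPU, each with a distinct reproducible seed.
seeds = [derive_seed(i, base_seed=1234) for i in range(8)]
print(seeds)

# Assumed usage in the generation script:
# llm = LLM(model=args.model, seed=derive_seed(worker_index, base_seed=1234))
```

Keeping the seed within 32 bits avoids range errors in generators that reject larger values, such as NumPy's legacy `np.random.seed`.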
