Commit

Merge remote-tracking branch 'upstream/main' into upstream_merge_24_12_09
gshtras committed Dec 9, 2024
2 parents c9f5c24 + b63ba84 commit c324ea8
Showing 37 changed files with 1,698 additions and 1,485 deletions.
3 changes: 2 additions & 1 deletion Dockerfile.neuron
@@ -1,5 +1,6 @@
# default base image
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.0-ubuntu20.04"
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04"

FROM $BASE_IMAGE

14 changes: 12 additions & 2 deletions csrc/cache_kernels.cu
@@ -307,10 +307,20 @@ void reshape_and_cache_flash(
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
int num_tokens = key.size(0);
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(1);
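The NOTE above is the substance of this change: in vLLM V1 the key/value tensors can carry CUDA-graph padding that slot_mapping does not, so the kernel now derives the token count from slot_mapping. A minimal sketch of that shape relationship, with made-up sizes and without calling the actual CUDA op:

```python
import torch

# Hypothetical sizes: 5 real tokens in this step, padded to 8 for a captured CUDA graph.
num_actual_tokens = 5
padded_num_tokens = 8
num_heads, head_size = 8, 128

# In V1 the key tensor includes the padding rows...
key = torch.randn(padded_num_tokens, num_heads, head_size)
# ...while slot_mapping only covers the real tokens.
slot_mapping = torch.arange(num_actual_tokens, dtype=torch.int64)

# Taking the count from slot_mapping works for both V0 (sizes equal)
# and V1 (key padded, slot_mapping not).
num_tokens = slot_mapping.size(0)
assert num_tokens <= key.size(0)  # 5 <= 8
```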
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines >= 0.0.43, < 0.1
xgrammar >= 0.1.5; platform_machine == "x86_64"
xgrammar >= 0.1.6; platform_machine == "x86_64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@e096d6f
25 changes: 2 additions & 23 deletions requirements-test.txt
@@ -1,8 +1,8 @@
#
# This file is autogenerated by pip-compile with Python 3.9
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile requirements-test.in
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
#
absl-py==2.1.0
# via rouge-score
@@ -27,10 +27,6 @@ anyio==4.6.2.post1
# via httpx
argcomplete==3.5.1
# via datamodel-code-generator
async-timeout==4.0.3
# via
# aiohttp
# redis
attrs==24.2.0
# via
# aiohttp
@@ -111,10 +107,6 @@ email-validator==2.2.0
# via pydantic
evaluate==0.4.3
# via lm-eval
exceptiongroup==1.2.2
# via
# anyio
# pytest
fastrlock==0.8.2
# via cupy-cuda12x
filelock==3.16.1
@@ -165,8 +157,6 @@ idna==3.10
# httpx
# requests
# yarl
importlib-resources==6.4.5
# via matplotlib
inflect==5.6.2
# via datamodel-code-generator
iniconfig==2.0.0
@@ -518,12 +508,6 @@ timm==1.0.11
# via -r requirements-test.in
tokenizers==0.20.3
# via transformers
toml==0.10.2
# via datamodel-code-generator
tomli==2.0.2
# via
# black
# pytest
torch==2.5.1
# via
# -r requirements-test.in
@@ -567,12 +551,9 @@ typepy[datetime]==1.3.2
# tabledata
typing-extensions==4.12.2
# via
# anyio
# black
# huggingface-hub
# librosa
# mistral-common
# multidict
# pydantic
# pydantic-core
# torch
@@ -590,8 +571,6 @@ xxhash==3.5.0
# evaluate
yarl==1.17.1
# via aiohttp
zipp==3.20.2
# via importlib-resources
zstandard==0.23.0
# via lm-eval

4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_vision.py
@@ -89,7 +89,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=772, total_tokens=782)
completion_tokens=10, prompt_tokens=775, total_tokens=785)

message = choice.message
message = chat_completion.choices[0].message
@@ -181,7 +181,7 @@ async def test_single_chat_session_image_base64encoded(
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=772, total_tokens=782)
completion_tokens=10, prompt_tokens=775, total_tokens=785)

message = choice.message
message = chat_completion.choices[0].message
4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_vision_embedding.py
@@ -95,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
assert len(embeddings["data"]) == 1
assert len(embeddings["data"][0]["embedding"]) == 3072
assert embeddings["usage"]["completion_tokens"] == 0
assert embeddings["usage"]["prompt_tokens"] == 762
assert embeddings["usage"]["total_tokens"] == 762
assert embeddings["usage"]["prompt_tokens"] == 765
assert embeddings["usage"]["total_tokens"] == 765
49 changes: 33 additions & 16 deletions tests/lora/test_layers.py
@@ -28,7 +28,7 @@
# yapf: enable
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights)
from vllm.lora.punica import PunicaWrapper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -48,11 +48,12 @@
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
CUDA_DEVICES = [
# TODO: Modify this based on platform
DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# We will launch different triton kernels between the prefill and decode
# For GPU, we will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]

@@ -192,9 +193,18 @@ def create_random_inputs(
return inputs, index_mapping, prompt_mapping


def check_punica_wrapper(punica_wrapper) -> bool:
if current_platform.is_cuda_alike():
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

return type(punica_wrapper) is PunicaWrapperGPU
else:
return False


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:

torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
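For context on the pattern introduced above: the tests now obtain the LoRA wrapper from a factory and assert that, on CUDA-like platforms, it is the GPU implementation. A minimal sketch mirroring the diff, assuming a vLLM tree at this commit and an available CUDA device; the import path for current_platform is assumed, and the 8192/256 arguments simply echo the test setup above:

```python
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.platforms import current_platform  # assumed import path


def check_punica_wrapper(punica_wrapper) -> bool:
    """On CUDA/ROCm the factory is expected to hand back the GPU wrapper."""
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
        return type(punica_wrapper) is PunicaWrapperGPU
    return False


# Same positional arguments as the tests above (8192, 256, device).
punica_wrapper = get_punica_wrapper(8192, 256, "cuda:0")
assert check_punica_wrapper(punica_wrapper)
```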
@@ -296,7 +307,7 @@ def create_random_embedding_layer():
# @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
@@ -432,7 +444,7 @@ def create_random_embedding_layer():

@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
@@ -563,15 +576,16 @@ def _pretest():

@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -675,15 +689,16 @@ def create_random_linear_replicated_layer():
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -797,15 +812,16 @@ def create_random_linear_parallel_layer():
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seed = 0
current_platform.seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
(Diffs for the remaining changed files are not shown here.)
