From c690357928fd2812f450bfb0c3629a816f5e9a55 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 9 Dec 2024 08:27:10 -0800 Subject: [PATCH 01/12] [V1] Fix Detokenizer loading in `AsyncLLM` (#10997) Signed-off-by: Roger Wang --- vllm/v1/engine/async_llm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4ef372fd8464b..0bcccda2bf329 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -65,7 +65,12 @@ def __init__( input_registry) # Detokenizer (converts EngineCoreOutputs --> RequestOutput). - self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) + self.detokenizer = Detokenizer( + tokenizer_name=vllm_config.model_config.tokenizer, + tokenizer_mode=vllm_config.model_config.tokenizer_mode, + trust_remote_code=vllm_config.model_config.trust_remote_code, + revision=vllm_config.model_config.tokenizer_revision, + ) # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_client( From e691b26f6fae5a3a1c220d15f20de83c7d78ed51 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 9 Dec 2024 11:44:27 -0500 Subject: [PATCH 02/12] [Core] Require xgrammar >= 0.1.6 (#11021) Signed-off-by: Russell Bryant --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 72fb020a82c4e..112528880c0ac 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 -xgrammar >= 0.1.5; platform_machine == "x86_64" +xgrammar >= 0.1.6; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs From aea2fc38c3b31b9a8ea7d1cffb8f37a2da6f6075 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Tue, 10 Dec 2024 01:24:46 +0800 Subject: [PATCH 03/12] [Platform] Move `async output` check to platform (#10768) Signed-off-by: wangxiyuan --- vllm/config.py | 17 +++-------------- vllm/platforms/cpu.py | 6 +++++- vllm/platforms/cuda.py | 12 +++++++++++- vllm/platforms/hpu.py | 6 +++++- vllm/platforms/interface.py | 11 +++++++++++ vllm/platforms/neuron.py | 6 +++++- vllm/platforms/openvino.py | 6 +++++- vllm/platforms/rocm.py | 12 +++++++++++- vllm/platforms/tpu.py | 6 +++++- vllm/platforms/xpu.py | 6 +++++- 10 files changed, 66 insertions(+), 22 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 7fbe04eaaf4f8..29f0839dcabba 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -513,11 +513,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid - if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"): + if not current_platform.is_async_output_supported(self.enforce_eager): logger.warning( - "Async output processing is only supported for CUDA, TPU, XPU " - "and HPU." 
- "Disabling it for other platforms.") + "Async output processing is not supported on the " + "current platform type %s.", current_platform.device_type) self.use_async_output_proc = False return @@ -527,16 +526,6 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - # Reminder: Please update docs/source/usage/compatibility_matrix.rst - # If the feature combo become valid - if device_config.device_type == "cuda" and self.enforce_eager: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. Since, enforce-eager is enabled, async output " - "processor cannot be used") - self.use_async_output_proc = not self.enforce_eager - return - # Async postprocessor is not necessary with embedding mode # since there is no token generation if self.task == "embedding": diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 680ee74129739..e5142b985d1f2 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import psutil import torch @@ -37,6 +37,10 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_device_total_memory(cls, device_id: int = 0) -> int: return psutil.virtual_memory().total + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 846a1869da228..edaf377b501df 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,7 +4,7 @@ import os from functools import lru_cache, wraps -from typing import TYPE_CHECKING, Callable, List, TypeVar +from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar import pynvml import torch @@ -88,6 +88,16 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. 
Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + @classmethod def is_full_nvlink(cls, device_ids: List[int]) -> bool: raise NotImplementedError diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 10aaa6d54962c..7f22bee3eaa74 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -20,6 +20,10 @@ class HpuPlatform(Platform): def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: return _Backend.HPU_ATTN + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @staticmethod def inference_mode(): return torch.no_grad() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 0be7df7941b8b..db06d2c18e681 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -6,11 +6,15 @@ import numpy as np import torch +from vllm.logger import init_logger + if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None +logger = init_logger(__name__) + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -147,6 +151,13 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: """Get the total memory of a device in bytes.""" raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + """ + Check if the current platform supports async output. + """ + raise NotImplementedError + @classmethod def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 87655ea198303..1e5c4bddfa24f 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from .interface import Platform, PlatformEnum @@ -18,6 +18,10 @@ class NeuronPlatform(Platform): def get_device_name(cls, device_id: int = 0) -> str: return "neuron" + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 29b61e955d9ab..e0f8e8b4b49fe 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -37,6 +37,10 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_device_name(self, device_id: int = 0) -> str: return "openvino" + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def inference_mode(self): return torch.inference_mode(mode=True) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3c14fbc179f69..66674e3ebe91f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,6 +1,6 @@ import os from functools import lru_cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -72,6 +72,16 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + 
logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index b138f7e1c54c5..10d874349f36b 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -35,6 +35,10 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 9665786f4c499..11dbd04d55671 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -41,6 +41,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @staticmethod def inference_mode(): return torch.no_grad() From 25b79d9fd38e2c53ce281be23241d8939ec7320c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 9 Dec 2024 12:33:41 -0500 Subject: [PATCH 04/12] [V1] Input Batch Relocation (#10962) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- vllm/v1/worker/gpu_input_batch.py | 280 +++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 273 +--------------------------- 2 files changed, 283 insertions(+), 270 deletions(-) create mode 100644 vllm/v1/worker/gpu_input_batch.py diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py new file mode 100644 index 0000000000000..457784bb0287c --- /dev/null +++ b/vllm/v1/worker/gpu_input_batch.py @@ -0,0 +1,280 @@ +# Datastructures defining an input batch + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set + +import numpy as np +import torch + +from vllm.multimodal import MultiModalKwargs +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.multimodal.inputs import PlaceholderRange + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + 
self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. 
+ num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. 
+ last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e8d964a722f60..7f95be06188e3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ import gc import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import numpy as np import torch @@ -15,16 +14,16 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs -from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: - from vllm.multimodal.inputs import PlaceholderRange from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) @@ -609,269 +608,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: if batch_size <= size: return size return None - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - 
self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - - # Attention-related. - self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. 
- num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. 
- last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 From edc4fa31888b4a41060acb7b16250540f051ad59 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Mon, 9 Dec 2024 11:46:58 -0800 Subject: [PATCH 05/12] [ci/build] Recompile CI dependencies list with Python 3.12 (#11013) Signed-off-by: kevin --- requirements-test.txt | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 19369254dbe26..38a064bca449a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # -# pip-compile requirements-test.in +# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt # absl-py==2.1.0 # via rouge-score @@ -27,10 +27,6 @@ anyio==4.6.2.post1 # via httpx argcomplete==3.5.1 # via datamodel-code-generator -async-timeout==4.0.3 - # via - # aiohttp - # redis attrs==24.2.0 # via # aiohttp @@ -111,10 +107,6 @@ email-validator==2.2.0 # via pydantic evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # pytest fastrlock==0.8.2 # via cupy-cuda12x filelock==3.16.1 @@ -165,8 +157,6 @@ idna==3.10 # httpx # requests # yarl -importlib-resources==6.4.5 - # via matplotlib inflect==5.6.2 # via datamodel-code-generator iniconfig==2.0.0 @@ -518,12 +508,6 @@ timm==1.0.11 # via -r requirements-test.in tokenizers==0.20.3 # via transformers -toml==0.10.2 - # via datamodel-code-generator -tomli==2.0.2 - # via - # black - # pytest torch==2.5.1 # via # -r requirements-test.in @@ -567,12 +551,9 @@ typepy[datetime]==1.3.2 # tabledata typing-extensions==4.12.2 # via - # anyio - # black # huggingface-hub # librosa # mistral-common - # multidict # pydantic # pydantic-core # torch @@ -590,8 +571,6 @@ xxhash==3.5.0 # evaluate yarl==1.17.1 # via aiohttp -zipp==3.20.2 - # via importlib-resources zstandard==0.23.0 # via lm-eval From 3b61cb450d899dc423feb264c297d4d18d701678 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 9 Dec 2024 12:38:46 -0800 
Subject: [PATCH 06/12] [V1] Further reduce CPU overheads in flash-attn (#10989) Signed-off-by: Woosuk Kwon --- csrc/cache_kernels.cu | 14 ++++++++++++-- vllm/v1/attention/backends/flash_attn.py | 21 ++++++++++++++++----- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43c..8a95279f9a25a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -307,10 +307,20 @@ void reshape_and_cache_flash( torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] const std::string& kv_cache_dtype, const double k_scale, const double v_scale) { - int num_tokens = key.size(0); + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); int num_heads = key.size(1); int head_size = key.size(2); int block_size = key_cache.size(1); diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d37989055c2e5..251a103e60f06 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -138,14 +138,25 @@ def forward( # Profiling run. return output - num_actual_tokens = attn_metadata.num_actual_tokens + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + num_actual_tokens = attn_metadata.num_actual_tokens # Reshape the input keys and values and store them in the cache. - key_cache = kv_cache[0] - value_cache = kv_cache[1] + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] and + # value[:num_actual_tokens] because the reshape_and_cache_flash op uses + # the slot_mapping's shape to determine the number of actual tokens. 
+ key_cache, value_cache = kv_cache.unbind(0) torch.ops._C_cache_ops.reshape_and_cache_flash( - key[:num_actual_tokens], - value[:num_actual_tokens], + key, + value, key_cache, value_cache, attn_metadata.slot_mapping, From ca871491edb0fba11fe9aa94300bd8d282fa29e1 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 10 Dec 2024 04:54:44 +0800 Subject: [PATCH 07/12] [Misc][LoRA] Abstract PunicaWrapper (#10955) Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 49 +- vllm/lora/layers.py | 7 +- vllm/lora/models.py | 8 +- vllm/lora/punica.py | 725 -------------------- vllm/lora/punica_wrapper/__init__.py | 7 + vllm/lora/punica_wrapper/punica_base.py | 480 +++++++++++++ vllm/lora/punica_wrapper/punica_gpu.py | 358 ++++++++++ vllm/lora/punica_wrapper/punica_selector.py | 14 + vllm/lora/punica_wrapper/utils.py | 159 +++++ 9 files changed, 1058 insertions(+), 749 deletions(-) delete mode 100644 vllm/lora/punica.py create mode 100644 vllm/lora/punica_wrapper/__init__.py create mode 100644 vllm/lora/punica_wrapper/punica_base.py create mode 100644 vllm/lora/punica_wrapper/punica_gpu.py create mode 100644 vllm/lora/punica_wrapper/punica_selector.py create mode 100644 vllm/lora/punica_wrapper/utils.py diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index a113e3f7abc1e..fb8c0b2a7ba26 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -28,7 +28,7 @@ # yapf: enable from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, PackedLoRALayerWeights) -from vllm.lora.punica import PunicaWrapper +from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -48,11 +48,12 @@ torch.float32: (5e-3, 5e-3), torch.bfloat16: (3e-2, 2e-2), } -CUDA_DEVICES = [ +# TODO: Modify this based on platform +DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -# We will launch different triton kernels between the prefill and decode +#For GPU, we will launch different triton kernels between the prefill and decode # stages, so we need to verify this. 
prefill stage(True) or decode stage(False) STAGES = [True, False] @@ -192,9 +193,18 @@ def create_random_inputs( return inputs, index_mapping, prompt_mapping +def check_punica_wrapper(punica_wrapper) -> bool: + if current_platform.is_cuda_alike(): + from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU + + return type(punica_wrapper) is PunicaWrapperGPU + else: + return False + + @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: @@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -296,7 +307,7 @@ def create_random_embedding_layer(): # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings_with_new_embeddings(dist_init, num_loras, device, @@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -432,7 +444,7 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, @@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -563,7 +576,7 @@ def _pretest(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_replicated(dist_init, num_loras, device, stage, @@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = 
LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -675,7 +689,7 @@ def create_random_linear_replicated_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, @@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -797,7 +812,7 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, @@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seed = 0 current_platform.seed_everything(seed) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3e9c2ceb83eac..38cb846578d5c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -17,7 +17,6 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_gather) from vllm.distributed.utils import divide -from vllm.lora.punica import PunicaWrapper # yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, @@ -33,7 +32,7 @@ VocabParallelEmbedding) if TYPE_CHECKING: - pass + from vllm.lora.punica_wrapper import PunicaWrapperBase def _get_lora_device(base_layer: nn.Module) -> torch.device: @@ -115,9 +114,9 @@ def set_lora( def set_mapping( self, - punica_wrapper: PunicaWrapper, + punica_wrapper, ): - self.punica_wrapper: PunicaWrapper = punica_wrapper + self.punica_wrapper: PunicaWrapperBase = punica_wrapper @classmethod def can_replace_layer( diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9855b57d0c9c9..49cd9f0c236ad 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -21,7 +21,7 @@ LinearScalingRotaryEmbeddingWithLora, LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.punica import PunicaWrapper +from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, is_regex_target_modules, 
parse_fine_tuned_lora_name, replace_submodule) @@ -331,9 +331,9 @@ def __init__( self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None - self.punica_wrapper = PunicaWrapper(max_num_batched_tokens, - max_batches=self.max_num_seqs, - device=self.device) + self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py deleted file mode 100644 index 563d1181d6fcb..0000000000000 --- a/vllm/lora/punica.py +++ /dev/null @@ -1,725 +0,0 @@ -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. -https://arxiv.org/abs/2310.18547 -""" - -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union - -import torch - -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.lora.ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.sgmv_shrink import sgmv_shrink - -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext - - -def compute_meta( - token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: - """ - Get the information required for the sgmv kernel. With the features: - 1. If consecutive requests in the batch use the same LoRA, this function - will combine them into a single request, improving sgmv kernel inference - performance. - 2. At the beginning of each prefill stage inference, recalculations are - needed based on the input, but only once. - """ - - lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( - token_lora_tensor, return_counts=True) - cum_result = torch.cumsum(seq_length_tensor, dim=0) - b_seq_start_tensor = torch.zeros_like(seq_length_tensor) - b_seq_start_tensor[1:].copy_(cum_result[:-1]) - max_length = seq_length_tensor.max().item() - token_nums = seq_length_tensor.sum().item() - batch_size = lora_indices_tensor.size(0) - no_lora = False - # -1 means no lora should be applied. Use `no_lora` to determine whether - # the current step requires LoRA. If LoRA is not needed, the prefill stage - # does not need to launch the triton kernel, which can improve performance - if batch_size == 1 and lora_indices_tensor == -1: - no_lora = True - return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, no_lora) - - -# TODO see if this can be vectorized -def convert_mapping( - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - device: torch.device, - long_lora_context: Optional["LongContextLoRAContext"] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: - """Converts LoRAMapping to index tensors. - - Args: - mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_index_to_id: List mapping LoRA ids to LoRA indices. - max_loras: Maximum number of LoRAs. 
- vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. - - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - LoRA indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - LoRA indices for sampler. For generation, this will be the - same as base_indicies. For prefill, this will map requests - to LoRA indices. - sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to LoRA indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_loras. - embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. - indices_len: List of lengths of the above tensors. It contains - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - embedding_indices = index_mapping_indices.copy() - lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=device, - dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - lora_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, - lora_indices, - embedding_indices, - ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor(prompt_mapping, - dtype=torch.long, - device=device) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( - sampler_indices_padded * len(sampler_indices_padded)) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] - # Contain length of indices tensors. Used to index into each tensor. 
- indices_len = [ - base_indices.shape[-1], - sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1], - ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - else: - # If long_lora doesn't exist,append None - indices_len.append(None) - - return ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_indices, - indices_len, - ) - - -class PunicaWrapper: - """ - PunicaWrapper is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for - Multi-LoRA, and to provide the interface for the punica kernel. - """ - - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str]): - self._token_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices_padded = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._embeddings_indices = torch.empty(2, - max_num_batched_tokens, - dtype=torch.long, - device=device) - self._long_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - - # 5 is the number of indicies tensors. - # base_indices, sampler_indices, sampler_indices_padded, - # embeddings_indices,long_lora_indices - self.indices_len: List[Optional[int]] = [None] * 5 - # these attributes are the information required for sgmv kernel - self._seq_start_locs = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._seq_lengths = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._lora_indices_per_batch = torch.empty(max_batches, - dtype=torch.long, - device=device) - self.device: torch.device = device - self.max_length: int = 0 - self.token_nums: int = 0 - self.batch_size: int = -1 - self.is_prefill = False - self.no_lora = False - - def update_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) - if mapping.is_prefill: - # Update metadata required for prefill-related operators. - self._update_prefill_metada(self.token_lora_indices) - self.is_prefill = True - else: - self.is_prefill = False - - def _update_base_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_offsets_tensor, - indices_len, - ) = convert_mapping( - mapping, - lora_index_to_id, - max_loras, - vocab_size, - extra_vocab_size, - self.device, - long_lora_context, - ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self._embeddings_indices[:embeddings_indices. 
- shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) - if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) - else: - self._long_lora_indices.zero_() - self.indices_len[:] = indices_len - - def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: - - (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, - no_lora) = compute_meta(token_lora_tensor) - - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( - b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( - lora_indices_tensor) - self.batch_size = batch_size - self.max_length = max_length - self.token_nums = token_nums - self.no_lora = no_lora - - @property - def prefill_metadata( - self - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: - """ - This property provides a convenient way to access the necessary - metadata for prefill-related kernel computations. - 1. seq_start_locs: Tensor of sequence start positions. - 2. seq_lengths: Tensor of sequence lengths. - 3. lora_indices_per_batch: Tensor of lora indices, and an index of - -1 means no lora should be applied. - 4. batch_size: Batch size after clustering identical lora indices. - 5. max_length: The maximum sequence length in the batch. - 6. token_nums: The token numbers in the batch. - """ - return (self._seq_start_locs[:self.batch_size], - self._seq_lengths[:self.batch_size], - self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length, self.token_nums) - - @property - def token_lora_indices(self) -> torch.Tensor: - """ - This property provides the lora indices corresponding to each token - in the batch. An index of -1 means no lora should be applied. - """ - token_lora_len = self.indices_len[0] - return self._token_lora_indices[:token_lora_len] - - @property - def sampler_indices(self) -> torch.Tensor: - """ - This property is used to access the lora indices specifically for - LogitsProcessorWithLoRA. - """ - sampler_indices_len = self.indices_len[1] - return self._sampler_indices[:sampler_indices_len] - - @property - def sampler_indices_padded(self) -> torch.Tensor: - """ - This property provides access to padded sampler indices. - """ - indices_padded_len = self.indices_len[2] - return self._sampler_indices_padded[:indices_padded_len] - - @property - def embeddings_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA. - """ - embeddings_indices_len = self.indices_len[3] - return self._embeddings_indices[:, :embeddings_indices_len] - - @property - def long_lora_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora. 
- """ - long_lora_len = self.indices_len[4] - return self._long_lora_indices[:long_lora_len] - - def _shrink_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_shrink( - x, - w_t_all, - y, - *self.prefill_metadata, - scale, - ) - - def _shrink_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - - def _expand_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand( - x, - w_t_all, - y, - *self.prefill_metadata, - add_input, - ) - - def _expand_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool, - ): - bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) - - def _expand_slice_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand_slice( - x, - w_t_all, - y, - *self.prefill_metadata, - y_offset, - y_slice_size, - add_input, - ) - - def _expand_slice_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool, - ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_input) - - def _apply_expand(self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool = True): - """ - Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` - computation, which is suitable for the - GEMM of lora'b. - """ - - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) - - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: Tuple[int, ...], - lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left:offset_left + slice] += bias - offset_left += slice - - return output.view_as(org_output) - - def _apply_shrink( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. 
- """ - y_org = y - y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) - shrink_fun(y, x, w_t_all, scale) - y = y.view_as(y_org) - - def add_shrink( - self, - y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, - ): - """ - Performs GEMM for multiple slices of lora_a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += (x @ lora_a_stacked[i]) * scale - - Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors - x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights - scale (float): Scaling factor for the operation - """ - - x = x.view(-1, x.shape[-1]) - # TODO fuse these kernels - for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) - - def add_expand( - self, - y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], - offset_start: int = 0, - add_input=True, - ) -> None: - """ - Performs GEMM and bias addition for multiple slices of lora_b. - - Semantics: - for i in range(len(lora_b_stacked)): - slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] - offset += slice - - Args: - y (torch.Tensor): Output tensor. - x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): - bias's weight - output_slices (Tuple[int, ...]): Every slice's size - add_input (bool): Defaults to True. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - offset_left = offset_start - if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) - for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_input=add_input, - ) - offset_left += output_slices[slice_idx] - y = y.view_as(y_org) - - def add_lora_embedding( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_input: bool = True, - ): - """ - Applies lora specifically for VocabParallelEmbeddingWithLoRA. - - Semantics: - y += x @ lora_b_stacked - - Args: - y (torch.Tensor): Output tensor. - x (torch.Tensor): Input tensor. - lora_b_stacked (torch.Tensor): lora_b's weights. - add_input (bool): Default to True. - - """ - - # Embedding layer only need expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) - expand_fun(y, x, lora_b_stacked, add_input) - - def add_lora_linear( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - scale: float, - output_slices: Tuple[int, ...], - *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: - """ - Applicable to linear-related lora. 
- - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += ( - x[i].unsqueeze(0) - @ lora_a_stacked[indices[i], layer_idx, :, :] - @ lora_b_stacked[indices[i], layer_idx, :, :] - * scale - ).squeeze(0)+lora_bias_stacked[i] - - Args: - y (torch.Tensor): Output tensor. Will be changed in-place. - x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. - scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. - """ - - assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) - - if buffer is None: - r = lora_b_stacked[0].size(-1) - # We set the buffer to be float32 by default ,refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) - self.add_shrink(buffer, x, lora_a_stacked, scale) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_input=True) - - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None) -> None: - """ - Applies lora specifically for LogitsProcessorWithLoRA. - - Semantics: - buffer = (x @ lora_a_stacked) * scale - y += buffer @ lora_b_stacked - - Args: - y (torch.Tensor): Output tensor. - x (torch.Tensor): Input tensor. - lora_a_stacked (torch.Tensor): lora_a's weights. - lora_b_stacked (torch.Tensor):lora_b's weights. - scale (float): Scaling factor. - buffer (Optional[torch.Tensor]):Default to None. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - r = lora_b_stacked.size(-1) - if buffer is None: - # We set the buffer to be float32 by default ,refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - # LogitsProcessorWithLoRA always using bgmv. - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) - y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/__init__.py b/vllm/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000000000..48ada3926ea46 --- /dev/null +++ b/vllm/lora/punica_wrapper/__init__.py @@ -0,0 +1,7 @@ +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase +from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper + +__all__ = [ + "PunicaWrapperBase", + "get_punica_wrapper", +] diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py new file mode 100644 index 0000000000000..0a5a84bdd8deb --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -0,0 +1,480 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from .utils import compute_meta, convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +class PunicaWrapperABC(ABC): + """ + PunicaWrapper ABC. + """ + + @abstractmethod + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs, + ) -> None: + """ + Update the lora-related metadata + """ + raise NotImplementedError + + @abstractmethod + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> None: + """ + Performs GEMM for multiple slices of lora_a. + """ + + raise NotImplementedError + + @abstractmethod + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs, + ) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs, + ) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA, + and this layer only requires the expand operation. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + """ + + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + """ + raise NotImplementedError + + +class PunicaWrapperBase(PunicaWrapperABC): + """ + PunicaWrapperBase is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indicies tensors. 
+ # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.device: torch.device = device + self.max_length: int = 0 + self.token_nums: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + self.device, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.token_nums = token_nums + self.no_lora = no_lora + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions. + 2. 
seq_lengths: Tensor of sequence lengths. + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length, self.token_nums) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA. + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora. + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + @abstractmethod + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs) -> None: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. 
+ + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_input (bool): Defaults to True. + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + and this layer only requires the expand operation. + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_input (bool): Default to True. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py new file mode 100644 index 0000000000000..b2af29de129ce --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -0,0 +1,358 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import Callable, Optional, Tuple, Union, final + +import torch + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.sgmv_shrink import sgmv_shrink + +from .punica_base import PunicaWrapperBase + + +@final +class PunicaWrapperGPU(PunicaWrapperBase): + """ + PunicaWrapperGPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica triton kernel. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_input, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_input) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. 
+ """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_input (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_input=add_input, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_input (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_input) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. 
+ + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_input=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. 
+ bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py new file mode 100644 index 0000000000000..df6c1bdc7dd71 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -0,0 +1,14 @@ +from vllm.platforms import current_platform +from vllm.utils import print_info_once + +from .punica_base import PunicaWrapperBase + + +def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: + if current_platform.is_cuda_alike(): + # Lazy import to avoid ImportError + from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU + print_info_once("Using PunicaWrapperGPU.") + return PunicaWrapperGPU(*args, **kwargs) + else: + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py new file mode 100644 index 0000000000000..7360c8c09e3ac --- /dev/null +++ b/vllm/lora/punica_wrapper/utils.py @@ -0,0 +1,159 @@ +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + token_nums = seq_length_tensor.sum().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, no_lora) + + +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + device: torch.device, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. 
+ sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=device, + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device=device) + prompt_mapping_tensor = torch.tensor(prompt_mapping, + dtype=torch.long, + device=device) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. 
+ indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) From a811dd660856a5c222a1447fe1d93deccbc162fd Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 10 Dec 2024 04:55:10 +0800 Subject: [PATCH 08/12] [Model] merged input processor for Phi-3-Vision models (#10977) Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung --- tests/entrypoints/openai/test_vision.py | 4 +- .../openai/test_vision_embedding.py | 4 +- .../mm_processor_kwargs/test_phi3v.py | 136 ++------ tests/multimodal/test_processor_kwargs.py | 169 +++++----- vllm/inputs/registry.py | 4 +- vllm/model_executor/models/phi3v.py | 298 +++++------------- vllm/multimodal/processing.py | 29 +- 7 files changed, 235 insertions(+), 409 deletions(-) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 157d873a75b4d..a0b6edd566561 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -89,7 +89,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=772, total_tokens=782) + completion_tokens=10, prompt_tokens=775, total_tokens=785) message = choice.message message = chat_completion.choices[0].message @@ -181,7 +181,7 @@ async def test_single_chat_session_image_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=772, total_tokens=782) + completion_tokens=10, prompt_tokens=775, total_tokens=785) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index d0c43b47bf0af..425f2a10ec855 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -95,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, assert len(embeddings["data"]) == 1 assert len(embeddings["data"][0]["embedding"]) == 3072 assert embeddings["usage"]["completion_tokens"] == 0 - assert embeddings["usage"]["prompt_tokens"] == 762 - assert embeddings["usage"]["total_tokens"] == 762 + assert embeddings["usage"]["prompt_tokens"] == 765 + assert embeddings["usage"]["total_tokens"] == 765 diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py index 60a8f63eb5faa..c16192a1e1438 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py @@ -2,12 +2,10 @@ from typing import Optional import pytest -import torch -from transformers import AutoImageProcessor, AutoTokenizer +from transformers import AutoTokenizer -from vllm.inputs import InputContext, token_inputs +from vllm.inputs import InputContext, InputProcessingContext from 
vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID -from vllm.multimodal import MultiModalRegistry from .....conftest import _ImageAssets from ....utils import build_model_context @@ -17,15 +15,9 @@ # Wrap lazy imports to avoid initializing CUDA during test collection @pytest.fixture() -def input_processor_for_phi3v(): - from vllm.model_executor.models.phi3v import input_processor_for_phi3v - return input_processor_for_phi3v - - -@pytest.fixture() -def dummy_data_for_phi3v(): - from vllm.model_executor.models.phi3v import dummy_data_for_phi3v - return dummy_data_for_phi3v +def processor_for_phi3v(): + from vllm.model_executor.models.phi3v import Phi3VProcessor + return Phi3VProcessor @pytest.fixture() @@ -34,53 +26,6 @@ def get_max_phi3v_image_tokens(): return get_max_phi3v_image_tokens -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops", [4, 16, None]) -def test_input_mapper_override(model: str, image_assets: _ImageAssets, - num_crops: Optional[int]): - """Ensure that the [default] input mapper handles num_crops properly.""" - # We pass the processor kwargs here since for this model, we fall back to - # the default mapper; this will fall back to the HF mapper and forward - # mm_processor_kwargs to it. - mm_processor_kwargs = { - "num_crops": num_crops - } if num_crops is not None else {} - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=mm_processor_kwargs, - ) - - hf_processor = AutoImageProcessor.from_pretrained(model, - trust_remote_code=True, - **mm_processor_kwargs) - - mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(ctx.model_config) - - image = image_assets[0].pil_image - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ) - - vllm_result = mm_registry.map_input( - ctx.model_config, - {"image": image}, - ) - - assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"]) - assert torch.all( - hf_result["num_img_tokens"] == vllm_result["num_img_tokens"]) - - # For pixel values, the second axis should be the num_crops + 1 - # for the rescaled original image. The default value in VLLM falls - # back to the HF config, which is why we compare to the processor num_crops - assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) - assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1 - - @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("num_crops,expected_max_tokens", [ (4, 781), @@ -112,48 +57,20 @@ def test_max_tokens_override(get_max_phi3v_image_tokens, model: str, @pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [ - (4, 781, 1), - (4, 781, 2), - (16, 2653, 1), - (16, 2653, 2), -]) -def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int, - toks_per_img: int, num_imgs: int): - """Ensure dummy_data_for_phi3v handles num_crops properly.""" - # Same as the previous test - don't initialize mm_processor_kwargs - # in this test and assume that the kwargs will be correctly expanded by - # the partial when calling the dummy data func. 
- ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=None, - ) - - dummy_data = dummy_data_for_phi3v( - ctx=ctx, - seq_len=8192, # Should be bigger than num_imgs * toks_per_img - mm_counts={"image": num_imgs}, - num_crops=num_crops, - ) - sequence_data = dummy_data.seq_data - # Ensure we have the right number of placeholders per num_crops size - img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) - assert img_tok_count == toks_per_img * num_imgs - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [ - (4, 757, 1), - (4, 757, 2), - (16, 1921, 1), - (16, 1921, 2), -]) -def test_input_processor_override(input_processor_for_phi3v, - image_assets: _ImageAssets, model: str, - num_crops: int, expected_toks_per_img: int, - num_imgs: int): +@pytest.mark.parametrize( + "num_crops,expected_toks_per_img,num_imgs", + [ + (4, 757, 1), + (4, 757, 2), + (16, 1921, 1), + (16, 1921, 2), + # the default num_crops of phi-3.5-vision is 4 + (None, 757, 2), + (None, 757, 2), + ]) +def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets, + model: str, num_crops: Optional[int], + expected_toks_per_img: int, num_imgs: int): """Ensure input_processor_for_phi3v handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs # in this test and assume that the kwargs will be correctly expanded by @@ -163,19 +80,20 @@ def test_input_processor_override(input_processor_for_phi3v, tokenizer_name=model, trust_remote_code=True, ) - tokenizer = AutoTokenizer.from_pretrained(model) + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + ctx = InputProcessingContext(ctx.model_config, tokenizer) # Build the image str / prompt based on the number of images we pass img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" images = [image_assets[0].pil_image] * num_imgs - inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), - prompt=prompt, - multi_modal_data={"image": images}) + mm_data = {"image": images} + mm_processor_kwargs = {} + if num_crops is not None: + mm_processor_kwargs = {"num_crops": num_crops} - processed_inputs = input_processor_for_phi3v(ctx, - inputs, - num_crops=num_crops) + processor = processor_for_phi3v(ctx) + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index e6c8793989e13..d141cdf1f083b 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -15,13 +15,13 @@ # Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" # Used for tests that need a multimodal model -MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" +MULTIMODAL_MODEL_ID = "OpenGVLab/InternVL2-2B" # For mm_processor_kwargs - we test overrides by defining mocks for each place # it is used, and ensuring that we can pass processor kwargs an override value # to receive the intended result for things like sequence length etc. 
-DEFAULT_NUM_CROPS = 4 -NUM_CROPS_OVERRIDE = 16 +DEFAULT_MAX_DYNAMIC_PATCH = 6 +MAX_DYNAMIC_PATCH_OVERRIDE = 4 # Mocks for all of the places that we use the mm_processor_kwargs @@ -33,10 +33,11 @@ def use_processor_mock(): def custom_processor(ctx: InputContext, inputs: DecoderOnlyInputs, *, - num_crops=DEFAULT_NUM_CROPS): + max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH): # For testing purposes, we don't worry about the prompt - return token_inputs(prompt_token_ids=[], - mm_processor_kwargs={"num_crops": num_crops}) + return token_inputs( + prompt_token_ids=[], + mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch}) with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): @@ -52,9 +53,9 @@ def custom_dummy_data_factory(self, seq_len: int, mm_counts: Mapping[str, int], *, - num_crops=DEFAULT_NUM_CROPS): + max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH): seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * max_dynamic_patch)) return DummyData(seq_data, None) with patch( @@ -65,15 +66,15 @@ def custom_dummy_data_factory(self, # Lazy import to avoid CUDA reinitialization error def mm_model_cls(): - from vllm.model_executor.models.phi3v import Phi3VForCausalLM + from vllm.model_executor.models.internvl import InternVLChatModel - return Phi3VForCausalLM + return InternVLChatModel # lambda whose signature matches max token calcs extra & mapper + extra kwargs -get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops -custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { - "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) +get_max_dynamic_patch = lambda ctx, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: max_dynamic_patch # noqa: E501 +custom_mapper = lambda ctx, data, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: { # noqa: E501 + "pixel_values": torch.zeros(size=(1, max_dynamic_patch + 1, 3, 448, 448)) } @@ -88,27 +89,28 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs -def _get_num_crops_info(init_num_crops: int, inference_num_crops: int): - """Get the init / inference kwargs and expected num_crops for this test.""" - # If we have a value for num_crops, pass the override value and make +def _get_max_dynamic_patch_info(init_max_dynamic_patch: int, + inference_max_dynamic_patch: int): + """Get the init / inference kwargs and expected max_dynamic_patch.""" + # If we have a value for max_dynamic_patch, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value - init_kwargs = None if init_num_crops is None else { - "num_crops": init_num_crops + init_kwargs = None if init_max_dynamic_patch is None else { + "max_dynamic_patch": init_max_dynamic_patch } - inference_kwargs = None if inference_num_crops is None else { - "num_crops": inference_num_crops + inference_kwargs = None if inference_max_dynamic_patch is None else { + "max_dynamic_patch": inference_max_dynamic_patch } - if inference_num_crops is not None: - expected_seq_count = inference_num_crops - elif init_num_crops is not None: - expected_seq_count = init_num_crops + if inference_max_dynamic_patch is not None: + expected_seq_count = inference_max_dynamic_patch + elif init_max_dynamic_patch is not None: + expected_seq_count = init_max_dynamic_patch else: - expected_seq_count = DEFAULT_NUM_CROPS + expected_seq_count = DEFAULT_MAX_DYNAMIC_PATCH return init_kwargs, 
inference_kwargs, expected_seq_count -def _get_processed_num_crops( +def _get_processed_max_dynamic_patch( processor: Callable[[ProcessorInputs], ProcessorInputs], inference_kwargs: Optional[Dict[str, int]], ) -> int: @@ -120,27 +122,30 @@ def _get_processed_num_crops( assert "type" in processed_inputs assert processed_inputs["type"] == "token" assert "mm_processor_kwargs" in processed_inputs - return processed_inputs["mm_processor_kwargs"]["num_crops"] + return processed_inputs["mm_processor_kwargs"]["max_dynamic_patch"] -@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ - (None, None), - (NUM_CROPS_OVERRIDE, None), - (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), -]) -def test_input_processor_kwargs(use_processor_mock, init_num_crops, - inference_num_crops): +@pytest.mark.parametrize( + "init_max_dynamic_patch,inference_max_dynamic_patch", [ + (None, None), + (MAX_DYNAMIC_PATCH_OVERRIDE, None), + (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE), + ]) +def test_input_processor_kwargs(use_processor_mock, init_max_dynamic_patch, + inference_max_dynamic_patch): """Ensure input processors can use processor kwargs.""" dummy_registry = InputRegistry() - init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( - init_num_crops, inference_num_crops) + (init_kwargs, inference_kwargs, + expected_seq_count) = _get_max_dynamic_patch_info( + init_max_dynamic_patch, inference_max_dynamic_patch) ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = _get_processed_num_crops(processor, inference_kwargs) + max_dynamic_patch_val = _get_processed_max_dynamic_patch( + processor, inference_kwargs) - assert num_crops_val == expected_seq_count + assert max_dynamic_patch_val == expected_seq_count @pytest.mark.parametrize( @@ -165,18 +170,21 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor = dummy_registry.create_input_processor(ctx.model_config) # Should filter out the inference time kwargs - num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs) - assert num_crops_val == DEFAULT_NUM_CROPS + max_dynamic_patch_val = _get_processed_max_dynamic_patch( + processor, mm_processor_kwargs) + assert max_dynamic_patch_val == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the dummy data -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): +@pytest.mark.parametrize("max_dynamic_patch", + [None, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_dummy_data_kwarg_overrides(use_dummy_data_mock, max_dynamic_patch): """Ensure dummy data factories can use processor kwargs.""" - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + mm_processor_kwargs = None if max_dynamic_patch is None else { + "max_dynamic_patch": max_dynamic_patch } - expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH + if max_dynamic_patch is None else max_dynamic_patch) dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs) @@ -217,17 +225,20 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, # len is solely dependent on the value of the mm_processor_kwargs. 
dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + assert len( + dummy_data.seq_data.prompt_token_ids) == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the max token count per multimodal instance -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_max_tokens_kwarg_overrides(num_crops): +@pytest.mark.parametrize("max_dynamic_patch", + [None, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_max_tokens_kwarg_overrides(max_dynamic_patch): """Ensure max token calcs can use processor kwargs.""" - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + mm_processor_kwargs = None if max_dynamic_patch is None else { + "max_dynamic_patch": max_dynamic_patch } - expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH + if max_dynamic_patch is None else max_dynamic_patch) ctx = build_model_context(MULTIMODAL_MODEL_ID, task="generate", @@ -239,11 +250,11 @@ def test_max_tokens_kwarg_overrides(num_crops): mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", - {mm_model_cls(): get_num_crops}, + {mm_model_cls(): get_max_dynamic_patch}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( ctx.model_config) @@ -279,26 +290,29 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", - {mm_model_cls(): get_num_crops}, + {mm_model_cls(): get_max_dynamic_patch}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( ctx.model_config) - assert max_multimodal_tokens == DEFAULT_NUM_CROPS + assert max_multimodal_tokens == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the mapper -@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) -def test_default_mapper_with_processor_kwargs(image_assets, num_crops): +@pytest.mark.parametrize( + "max_dynamic_patch", + [DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_default_mapper_with_processor_kwargs(image_assets, max_dynamic_patch): """Ensure that the mapper processor kwargs can fall back to HF models.""" # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. 
- ctx = build_model_context(MULTIMODAL_MODEL_ID, - task="generate", - trust_remote_code=True, - mm_processor_kwargs={"num_crops": num_crops}, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context( + MULTIMODAL_MODEL_ID, + task="generate", + trust_remote_code=True, + mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch}, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) @@ -307,20 +321,22 @@ def test_default_mapper_with_processor_kwargs(image_assets, num_crops): mm_inputs = {"image": image} mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) - # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] - assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 + # pixel vals should have shape: [batch, max_dynamic_patch+1, ...] + assert mapped_inputs["pixel_values"].shape[1] == max_dynamic_patch + 1 -@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ - (None, None), - (NUM_CROPS_OVERRIDE, None), - (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), -]) -def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, - inference_num_crops): +@pytest.mark.parametrize( + "init_max_dynamic_patch,inference_max_dynamic_patch", [ + (None, None), + (MAX_DYNAMIC_PATCH_OVERRIDE, None), + (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE), + ]) +def test_custom_mapper_kwarg_overrides(image_assets, init_max_dynamic_patch, + inference_max_dynamic_patch): """Ensure custom mappers can use processor kwargs.""" - init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( - init_num_crops, inference_num_crops) + (init_kwargs, inference_kwargs, + expected_seq_count) = _get_max_dynamic_patch_info( + init_max_dynamic_patch, inference_max_dynamic_patch) ctx = build_model_context(MULTIMODAL_MODEL_ID, task="generate", @@ -335,7 +351,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( mm_model_cls()) mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs, @@ -373,11 +389,12 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. 
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( mm_model_cls()) # Should filter out the inference time kwargs mapped_inputs = mm_registry.map_input( ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs) - assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 + assert mapped_inputs["pixel_values"].shape[1] == ( + DEFAULT_MAX_DYNAMIC_PATCH + 1) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 646554c72481a..0dfed3b7e61bf 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -69,12 +69,12 @@ class InputProcessingContext(InputContext): tokenizer: AnyTokenizer """The tokenizer used to tokenize the inputs.""" - def get_hf_processor(self) -> ProcessorMixin: + def get_hf_processor(self, **kwargs) -> ProcessorMixin: return cached_get_processor( self.model_config.tokenizer, tokenizer=self.tokenizer, # Override the tokenizer with ours trust_remote_code=self.model_config.trust_remote_code, - ) + **kwargs) N = TypeVar("N", bound=Type[nn.Module]) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index eef23029a2aca..3c7854ce388ab 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -12,22 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools -import re -from functools import cached_property, lru_cache -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) -import numpy as np import torch import torch.nn as nn -from PIL import Image -from transformers import CLIPVisionConfig, PretrainedConfig +from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, + ProcessorMixin) from vllm.attention import AttentionMetadata -from vllm.config import ModelConfig, VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.config import VllmConfig +from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -36,12 +32,18 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors, PlaceholderRange -from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, + ModalityProcessingMetadata, + MultiModalDataDict, + MultiModalProcessingMetadata, + PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .clip import dummy_image_for_clip, dummy_seq_data_for_clip +from .clip import dummy_image_for_clip from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -303,231 +305,99 @@ def add_image_newline(self, image_features_hd): return 
image_features_hd_newline -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 -def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): - target_height = int(np.ceil(height / padding_unit) * padding_unit) - top_padding = int((target_height - height) / 2) - bottom_padding = target_height - height - top_padding - padded_width = width - padded_height = height + top_padding + bottom_padding - return padded_width, padded_height - - -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 -def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): - transposed = False - if width < height: - width, height = height, width - transposed = True - - ratio = width / height - scale = 1 - while scale * np.ceil(scale / ratio) <= hd_num: - scale += 1 - scale -= 1 - - new_width = int(scale * 336) - new_height = int(new_width / ratio) - - padded_width, padded_height = _calc_padded_size(width=new_width, - height=new_height) - - if transposed: - padded_width, padded_height = padded_height, padded_width - - return padded_width, padded_height - - -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181 -def get_phi3v_image_feature_size( - hf_config: Dict[str, Any], - *, - input_height: int, - input_width: int, - num_crops: int, -) -> int: - if num_crops is None: - num_crops = hf_config.get("num_crops", 16) - new_width, new_height = _calc_hd_transform_size(width=input_width, - height=input_height, - hd_num=num_crops) - - return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \ - + (new_height // 336 + 1) * 12 - - def get_max_phi3v_image_tokens(ctx: InputContext, *, num_crops: Optional[int] = None): + mm_processor_kwargs = {} + if num_crops is not None: + mm_processor_kwargs["num_crops"] = num_crops - return get_phi3v_image_feature_size( - ctx.get_hf_image_processor_config(), - input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - num_crops=num_crops, + model_config = ctx.model_config + image_processor = cached_get_image_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs, + ) + + num_tokens = image_processor.calc_num_image_tokens_from_image_size( + width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) + return num_tokens -def dummy_data_for_phi3v(ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - *, - num_crops: Optional[int] = None): +def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext, + mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) - - seq_data, ranges = dummy_seq_data_for_clip( - CLIP_VIT_LARGE_PATCH14_336_CONFIG, - seq_len, - num_images, - image_token_id=_IMAGE_TOKEN_ID, - image_feature_size_override=image_feature_size, - ) - mm_data = dummy_image_for_clip( + data = dummy_image_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, num_images, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return DummyData(seq_data, mm_data, ranges) - + hf_processor = ctx.get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt") -@lru_cache -def _get_image_placeholder_token_id_candidates( - model_config: ModelConfig, - idx: int, -) -> 
List[List[int]]: - assert idx > 0 + return MultiModalKwargs(**hf_inputs) - tokenizer = cached_get_tokenizer(model_config.tokenizer) - # This is used when the image token is at the start of the string - start_candidate = tokenizer.encode(f"<|image_{idx}|>", - add_special_tokens=False) +def create_metadata_for_phi3v( + ctx: InputProcessingContext) -> MultiModalProcessingMetadata: + return { + "image": + ModalityProcessingMetadata(prompt_repls=[ + PromptReplacement(target=[_IMAGE_TOKEN_ID], + repl_unit=[_IMAGE_TOKEN_ID], + repl_count=get_max_phi3v_image_tokens(ctx)), + ]), + } - # This is used when the image token is in the middle of the string - # We need to get the token for "<", not "▁<" - # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json - a_token_id, = tokenizer.encode("a", add_special_tokens=False) - a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>", - add_special_tokens=False) - assert a_token_id == a_token_id_ - return [start_candidate, middle_candidate] +class Phi3VProcessor(BaseMultiModalProcessor): + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__( + ctx=ctx, + metadata=create_metadata_for_phi3v(ctx), + ) -def input_processor_for_phi3v(ctx: InputContext, - inputs: DecoderOnlyInputs, - *, - num_crops: Optional[int] = None): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - model_config = ctx.model_config - hf_config = ctx.get_hf_image_processor_config() - - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - w, h = image_data.size - image_feature_size = [ - get_phi3v_image_feature_size(hf_config, - input_width=w, - input_height=h, - num_crops=num_crops) - ] - image_data = [image_data] - elif is_list_of(image_data, Image.Image): - image_feature_size = [] - for image in image_data: - w, h = image.size - image_feature_size.append( - get_phi3v_image_feature_size(hf_config, - input_width=w, - input_height=h, - num_crops=num_crops)) - elif isinstance(image_data, torch.Tensor): - image_feature_size = [image_data.shape[0]] - image_data = [image_data] - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[0] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - prompt = inputs.get("prompt") - if prompt is None: - # for async server request, we assume prompt and its token_ids is always - # in correct format. And num_image_tags == len(image_data) always True. 
- image_idx = range(1, len(image_data) + 1) - new_prompt = None - else: - image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt))) - if prompt.count("<|image|>") > 0: - logger.warning("Please follow the prompt format that is " - "documented on HuggingFace which does not involve " - "repeating <|image|> tokens.") - elif (num_image_tags := len(image_idx)) > 1: - assert num_image_tags == len( - image_data), "The count of image_placeholder not match image's" - new_prompt = prompt - - prompt_token_ids = inputs["prompt_token_ids"].copy() - - # masked placeholder with image token id - for idx in image_idx: - candidates = _get_image_placeholder_token_id_candidates(model_config, - idx=idx) - - for candidate in candidates: - for i in range(len(prompt_token_ids) - len(candidate) + 1): - if prompt_token_ids[i:i + len(candidate)] == candidate: - prompt_token_ids[i:i + - len(candidate)] = ([_IMAGE_TOKEN_ID] * - len(candidate)) - break - - # merge consecutive tag ids - merged_token_ids: List[int] = [] - for is_placeholder, token_ids in itertools.groupby( - prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID): - if is_placeholder: - merged_token_ids.append(_IMAGE_TOKEN_ID) - else: - merged_token_ids.extend(list(token_ids)) - - # TODO: Move this to utils or integrate with clip. - new_token_ids: List[int] = [] - placeholder_ranges: List[PlaceholderRange] = [] - placeholder_idx = 0 - while merged_token_ids: - token_id = merged_token_ids.pop(0) - if token_id == _IMAGE_TOKEN_ID: - replacement_ids = repeat_and_pad_token( - _IMAGE_TOKEN_ID, - repeat_count=image_feature_size[placeholder_idx], - ) - placeholder_ranges.append({ - "offset": len(new_token_ids), - "length": len(replacement_ids) - }) - new_token_ids.extend(replacement_ids) - placeholder_idx += 1 - else: - new_token_ids.append(token_id) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": placeholder_ranges}) + def _get_hf_processor( + self, + *, + num_crops: Optional[int] = None, + ) -> ProcessorMixin: + if num_crops is not None: + return self.ctx.get_hf_processor(num_crops=num_crops) + return self.ctx.get_hf_processor() + + def _apply_hf_processor( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._apply_hf_processor( + prompt, mm_data, mm_processor_kwargs) + # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids, + # which will cause OverflowError when decoding the prompt_ids. 
+ # Therefore, we need to do an early replacement here + token_ids = processed_outputs['input_ids'] + token_ids[token_ids < 0] = _IMAGE_TOKEN_ID + processed_outputs['input_ids'] = token_ids + return processed_outputs + + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts) -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) -@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) +@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index c3a95d60e6fe6..922c83b6fd8a9 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,7 +3,8 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass from functools import lru_cache -from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union +from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol, + TypeVar, Union, cast) import torch from transformers import BatchFeature, ProcessorMixin @@ -11,7 +12,8 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import flatten_2d_lists, full_groupby, is_list_of +from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of, + resolve_mm_processor_kwargs) from .inputs import (AudioItem, ImageItem, MultiModalDataDict, MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, @@ -543,8 +545,14 @@ def __init__( self.ctx = ctx self.metadata = metadata + self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs + or {}) - def _get_hf_processor(self) -> ProcessorMixin: + def _get_hf_processor( + self, + **mm_processor_kwargs: Mapping[str, object], + ) -> ProcessorMixin: + # by default, we won't pass any kwargs to the processor initialization return self.ctx.get_hf_processor() def _get_tokenizer(self) -> AnyTokenizer: @@ -581,7 +589,13 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - hf_processor = self._get_hf_processor() + # some mm_processor_kwargs may be used in processor initialization + # instead of processor call + processor_init_kwargs = { + **self.init_mm_processor_kwargs, + **mm_processor_kwargs, + } + hf_processor = self._get_hf_processor(**processor_init_kwargs) processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() @@ -601,6 +615,13 @@ def _apply_hf_processor( else: processor_data[k] = v + # filter mm_processor_kwargs used in processor call + mm_processor_kwargs = resolve_mm_processor_kwargs( + self.init_mm_processor_kwargs, + cast(Dict[str, Any], mm_processor_kwargs), + hf_processor, + ) + try: hf_inputs = hf_processor( text=prompt, # type: ignore From cbcbdb1ceb9c219d13b2386e101992c399410551 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 9 Dec 2024 22:21:06 +0100 Subject: [PATCH 09/12] [Bugfix][Hardware][Gaudi] Bump vllm_hpu_extension version (#11028) Signed-off-by: Konrad Zawora --- requirements-hpu.txt | 2 +- vllm/attention/backends/hpu_attn.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git 
a/requirements-hpu.txt b/requirements-hpu.txt index 4674efb812cfd..17d40d0ee131a 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@e096d6f diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 2c62e565c04c7..f90d15d4207e7 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -111,8 +111,16 @@ def __init__( self.matmul_qk = Matmul() self.softmax = Softmax() self.matmul_av = Matmul() + self.batch2block_matmul = Matmul() + self.block2batch_matmul = Matmul() + # NOTE(kzawora): Contiguous PA is off until model runner supports it self.k_cache = VLLMKVCache() + self.k_cache.use_contiguous_pa = False self.v_cache = VLLMKVCache() + self.v_cache.use_contiguous_pa = False + # NOTE(kzawora): Pipelined PA is off until model runner supports it + ops.pa_impl = ops.pa + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window self.alibi_slopes = alibi_slopes @@ -228,9 +236,12 @@ def forward( block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, block_scales=attn_metadata.block_scales, + block_groups=None, scale=self.scale, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, + batch2block_matmul_op=self.batch2block_matmul, + block2batch_matmul_op=self.block2batch_matmul, keys_fetch_func=self.k_cache.fetch_from_cache, values_fetch_func=self.v_cache.fetch_from_cache) # Reshape the output tensor. From 1a2f8fb828f0444705db319786b2e901159f184e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 9 Dec 2024 13:47:24 -0800 Subject: [PATCH 10/12] [v1] fix use compile sizes (#11000) Signed-off-by: youkaichao --- vllm/config.py | 1 + vllm/v1/worker/gpu_model_runner.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 29f0839dcabba..5fb9563fcf3a3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2522,6 +2522,7 @@ def __post_init__(self): self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.use_inductor = True + self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_reshape = False self.compilation_config.level = CompilationLevel.PIECEWISE diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7f95be06188e3..c601aca13feaf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -582,6 +582,9 @@ def capture_model(self) -> None: # can reuse the memory pool allocated for the large shapes. with graph_capture(): for num_tokens in reversed(self.cudagraph_batch_sizes): + for _ in range(self.vllm_config.compilation_config. 
+ cudagraph_num_of_warmups): + self._dummy_run(self.model, num_tokens, self.kv_caches) self._dummy_run(self.model, num_tokens, self.kv_caches) end_time = time.perf_counter() From 9c6459e4cb020ec1ad9ea08cac9309b83d432fc8 Mon Sep 17 00:00:00 2001 From: xendo Date: Mon, 9 Dec 2024 22:53:24 +0100 Subject: [PATCH 11/12] [Neuron] Upgrade neuron to 2.20.2 (#11016) Signed-off-by: Jerzy Zagorski Co-authored-by: Jerzy Zagorski --- Dockerfile.neuron | 3 ++- vllm/utils.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 76dbd4c04d3f3..77162bc82de62 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -1,5 +1,6 @@ # default base image -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.0-ubuntu20.04" +# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04" FROM $BASE_IMAGE diff --git a/vllm/utils.py b/vllm/utils.py index 1f19d9eacd16d..2bb1fb2af40f4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1628,7 +1628,7 @@ def direct_register_custom_op( library object. If you want to bind the operator to a different library, make sure the library object is alive when the operator is used. """ - if is_in_doc_build(): + if is_in_doc_build() or not supports_custom_op(): return import torch.library if hasattr(torch.library, "infer_schema"): From b63ba848323efd88207b12d7582501d525503b8a Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:00:29 -0500 Subject: [PATCH 12/12] [ROCm][bugfix] scpecilative decoding worker class (#11035) Signed-off-by: Gregory Shtrasberg --- vllm/platforms/rocm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 66674e3ebe91f..0133f26a0b1bc 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -93,6 +93,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: elif vllm_config.speculative_config: parallel_config.worker_cls = \ "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker"
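
The ROCm override above selects workers by fully-qualified class path strings such as
"vllm.spec_decode.spec_decode_worker.create_spec_worker" and "vllm.worker.worker.Worker".
As a rough sketch of how such a dotted path can be turned into a class at runtime
(illustrative only: resolve_worker_cls is a hypothetical helper written for this note,
not vLLM's actual loader):

    # Illustrative sketch: resolve a dotted path like "pkg.module.ClassName"
    # into the class object it names. Not part of the patches above.
    import importlib


    def resolve_worker_cls(qualified_name: str) -> type:
        """Split 'pkg.module.Class' into module path and attribute, then import it."""
        module_name, _, class_name = qualified_name.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, class_name)


    if __name__ == "__main__":
        # Use a standard-library class so the sketch runs anywhere.
        cls = resolve_worker_cls("collections.OrderedDict")
        print(cls())  # -> OrderedDict()

Keeping the worker selection as a string in ParallelConfig (rather than a direct import)
lets platform modules such as rocm.py pick an implementation without importing it eagerly.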