diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index bee968b4d2e43..c6f8316412e2f 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -356,7 +356,7 @@ steps:
     - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
     - pytest -v -s models/embedding/language -m 'not core_model'
 
-- label: Multi-Modal Models Test (Standard) # 28min
+- label: Multi-Modal Models Test (Standard) # 40min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
@@ -372,7 +372,7 @@ steps:
     - pytest -v -s models/encoder_decoder/language -m core_model
     - pytest -v -s models/encoder_decoder/vision_language -m core_model
 
-- label: Multi-Modal Models Test (Extended) 1 # 1h16m
+- label: Multi-Modal Models Test (Extended) 1 # 48m
   optional: true
   source_file_dependencies:
   - vllm/
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 6c50882d83c3b..ffd6891b25965 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -33,7 +33,7 @@
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
+from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
                                     NestedTensors)
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config):
 
 def mm_input_mapper_for_glmv(
     ctx: InputContext,
-    data: MultiModalData[object],
+    data: ModalityData[object],
 ) -> Dict:
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 0ecba5a1cae0f..1d6ee2a0be72e 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -20,11 +20,13 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
-                                    MultiModalFieldConfig, MultiModalInputsV2,
-                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import ImageProcessorItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement,
                                         full_groupby_modality)
 from vllm.sequence import IntermediateTensors
 
@@ -179,7 +181,9 @@ def _get_prompt_replacements(
         assert isinstance(vision_config, PixtralVisionConfig)
 
         def get_replacement_pixtral(item_idx: int):
-            image_size = mm_items.get_image_size(item_idx)
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+
             (
                 num_width_tokens,
                 num_height_tokens,
@@ -591,8 +595,8 @@ def apply(
 
         result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
 
-        mm_items = self._get_mm_items(mm_data)
-        mm_item_counts = mm_items.get_item_counts()
+        mm_items = self._to_mm_items(mm_data)
+        mm_item_counts = mm_items.get_all_counts()
         mm_kwargs = result["mm_kwargs"]
 
         # We reimplement the functionality of MLlavaProcessor from
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index fefa9fd62d1d0..15362db6cdfbf 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -32,12 +32,13 @@
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
-                                    MultiModalFieldConfig, MultiModalInputsV2,
-                                    MultiModalKwargs, NestedTensors,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import ImageProcessorItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement,
                                         _BoundPromptReplacement,
                                         _PlaceholderInfo)
 from vllm.sequence import IntermediateTensors
@@ -381,7 +382,9 @@ def _get_prompt_replacements(
         assert isinstance(bos_token_id, int)
 
         def get_replacement_phi3v(item_idx: int):
-            image_size = mm_items.get_image_size(item_idx)
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+
             num_tokens = image_processor.calc_num_image_tokens_from_image_size(
                 width=image_size.width,
                 height=image_size.height,
@@ -389,12 +392,14 @@ def get_replacement_phi3v(item_idx: int):
 
             return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
 
+        num_images = mm_items.get_count("image", strict=False)
+
         return [
             PromptReplacement(
                 modality="image",
                 target=image_token,
                 replacement=get_replacement_phi3v,
-            ) for image_token in image_tokens[:len(mm_items.images)]
+            ) for image_token in image_tokens[:num_images]
         ]
 
     def _apply_prompt_replacements(
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index 25a351bd9c656..e3d43b017f894 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -20,8 +20,8 @@
 # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from functools import cached_property -from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import numpy as np import torch @@ -38,10 +38,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - ProcessorInputs, PromptReplacement) + MultiModalDataItems, ProcessorInputs, + PromptReplacement) from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -99,15 +101,9 @@ def _get_hf_processor( def _get_feature_extractor(self) -> WhisperFeatureExtractor: return self._get_hf_processor().feature_extractor # type: ignore - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - # resample audio to the model's sampling rate + def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self._get_feature_extractor() - mm_items.resample_audios(feature_extractor.sampling_rate) - - return super()._get_hf_mm_data(mm_items) + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 574845ef5a525..6181fe3dd13d8 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -25,7 +25,6 @@ from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Set, Tuple, Type, TypedDict, Union) -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -55,15 +54,16 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, +from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) + NestedTensors, VideoItem) +from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - ProcessorInputs, PromptReplacement) + MultiModalDataItems, ProcessorInputs, + PromptReplacement) from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope -from vllm.utils import is_list_of from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, @@ -719,61 +719,81 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key="video") -class Qwen2VLMultiModalDataItems(MultiModalDataItems): +class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], + dict[str, torch.Tensor]]): - @staticmethod - def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": - """ - Normalize :class:`MultiModalDataDict` to 
:class:`MultiModalDataItems`. - """ - multi_data = Qwen2VLMultiModalDataItems() - - for k, v in data.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - # yapf: disable - if k == "video": - # Special case since even a single item can be a list - multi_data[k] = ( # type: ignore[index] - v if ( - isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment] - or is_list_of(v, list) - or isinstance(v[0], (np.ndarray, torch.Tensor)) - and v[0].ndim == 4 - ) else [v] - ) - elif k in ("image", "audio"): - multi_data[k] = ( # type: ignore[index] - v if isinstance(v, (dict, torch.Tensor, list)) else [v] - ) - else: - multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] - # yapf: enable + def __init__(self, data: dict, modality: str) -> None: + super().__init__(data) - return multi_data + self.modality = modality - def get_item_counts(self) -> Mapping[str, int]: - return { - m: ( - len(items[f"{m}_grid_thw"]) # type: ignore - if isinstance(items, dict) else len(items)) - for m, items in self.items() - } + grid_thw = data[f"{modality}_grid_thw"] + slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist() + self._slices = [ + slice(slice_idxs[i], slice_idxs[i + 1]) + for i in range(len(grid_thw)) + ] - def has_embedding_inputs(self) -> bool: - return any( - isinstance(items, dict) or any( - isinstance(item, torch.Tensor) for item in items) - for items in self.values()) + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r})") + def get_count(self) -> int: + return len(self.data[f"{self.modality}_grid_thw"]) -class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): + def get(self, index: int) -> dict[str, torch.Tensor]: + out = {} + for k, v in self.data.items(): + if v != f"{self.modality}_grid_thw": + v = v[self._slices[index]] + + out[k] = v + + return out + + def get_processor_data(self) -> Mapping[str, object]: + return {} + + def get_passthrough_data(self) -> Mapping[str, object]: + return self.data + + +class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems): + + def __init__(self, data: dict) -> None: + super().__init__(data, "image") + + +class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems): - def _get_mm_items( + def __init__(self, data: dict) -> None: + super().__init__(data, "video") + + +class Qwen2MultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return Qwen2EmbeddingItems(data, modality="image") + + return super()._parse_image_data(data) + + def _parse_video_data( self, - mm_data: MultiModalDataDict, - ) -> MultiModalDataItems: - return Qwen2VLMultiModalDataItems.from_dict(mm_data) + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return Qwen2EmbeddingItems(data, modality="video") + + return super()._parse_video_data(data) + + +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): + + def _get_data_parser(self) -> MultiModalDataParser: + return Qwen2MultiModalDataParser() def _get_hf_processor( self, @@ -796,35 +816,6 @@ def _get_hf_processor( return hf_processor - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - processor_data = dict[str, Any]() - passthrough_data = dict[str, Any]() - - for k, v in mm_items.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion 
- if k in ("image", "video", "audio"): - if isinstance(v, dict): - # Pass through embedding inputs (dict) - passthrough_data.update(v) - elif isinstance(v, torch.Tensor) and v.ndim == 3: - # Pass through embedding inputs (single) - passthrough_data[f"{k}_embeds"] = [v] - elif (is_list_of(v, torch.Tensor) and len(v) > 0 - and v[0].ndim == 2): - # Pass through embedding inputs (multi) - passthrough_data[f"{k}_embeds"] = v - elif len(v) > 0: - # Map keys to plural form, e.g.: image -> images - processor_data[f"{k}s"] = v - else: - processor_data[k] = v - - return processor_data, passthrough_data - def _get_prompt_replacements( self, mm_items: MultiModalDataItems, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 7b4aeeec5f403..7e853e5b90096 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,8 +3,8 @@ import math from functools import cached_property, lru_cache -from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) import numpy as np import torch @@ -24,10 +24,12 @@ from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - ProcessorInputs, PromptReplacement) + MultiModalDataItems, ProcessorInputs, + PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils import is_list_of @@ -85,15 +87,9 @@ def _get_feature_extractor(self) -> WhisperFeatureExtractor: hf_processor = self._get_hf_processor() return hf_processor.audio_processor.feature_extractor # type: ignore - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - # resample audio to the model's sampling rate + def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self._get_feature_extractor() - mm_items.resample_audios(feature_extractor.sampling_rate) - - return super()._get_hf_mm_data(mm_items) + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 9255e062e4870..e58bbe81717a0 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,8 +1,7 @@ from .base import MultiModalPlaceholderMap, MultiModalPlugin -from .inputs import (BatchedTensorInputs, MultiModalData, - MultiModalDataBuiltins, MultiModalDataDict, - MultiModalKwargs, MultiModalPlaceholderDict, - NestedTensors) +from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, + MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict, NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -16,7 +15,7 @@ __all__ = [ "BatchedTensorInputs", - "MultiModalData", + "ModalityData", "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalKwargs", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 
3e09ef1fcbb56..de80f22bac2a3 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -9,7 +9,7 @@ from vllm.utils import PlaceholderModule from .base import MediaIO, MultiModalPlugin -from .inputs import AudioItem, MultiModalData, MultiModalKwargs +from .inputs import AudioItem, ModalityData, MultiModalKwargs try: import librosa @@ -31,7 +31,7 @@ def get_data_key(self) -> str: def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[AudioItem], + data: ModalityData[AudioItem], **mm_processor_kwargs, ) -> MultiModalKwargs: raise NotImplementedError("There is no default audio input mapper") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index cdda6f8052794..7f4029e726332 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -15,12 +15,12 @@ from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, +from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs, PlaceholderRange) logger = init_logger(__name__) -MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], +MultiModalInputMapper = Callable[[InputContext, ModalityData[object]], MultiModalKwargs] """ Return a dictionary to be passed as keyword arguments to @@ -69,7 +69,7 @@ def get_data_key(self) -> str: def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[Any], + data: ModalityData[Any], **mm_processor_kwargs, ) -> MultiModalKwargs: """ @@ -118,7 +118,7 @@ def wrapper(model_cls: N) -> N: def map_input( self, model_config: "ModelConfig", - data: MultiModalData[Any], + data: ModalityData[Any], mm_processor_kwargs: Optional[dict[str, Any]], ) -> MultiModalKwargs: """ diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 14c79dfadec0c..da13a381c4530 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -13,7 +13,7 @@ from vllm.utils import is_list_of from .base import MediaIO, MultiModalPlugin -from .inputs import ImageItem, MultiModalData, MultiModalKwargs +from .inputs import ImageItem, ModalityData, MultiModalKwargs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -44,7 +44,7 @@ def _get_hf_image_processor( def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[ImageItem], + data: ModalityData[ImageItem], **mm_processor_kwargs, ) -> MultiModalKwargs: model_config = ctx.model_config diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 1fbda6e0b8750..db489af7ac475 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -2,53 +2,74 @@ from collections import UserDict, defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass -from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast, - final) +from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final import numpy as np import torch import torch.types from PIL.Image import Image from transformers import BatchFeature -from typing_extensions import NotRequired, TypeAlias, assert_never +from typing_extensions import NotRequired, TypeAlias from vllm.utils import JSONTree, is_list_of, json_map_leaves _T = TypeVar("_T") -# yapf: disable -ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] +HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] """ A :class:`transformers.image_utils.ImageInput` representing a single image item, which can be passed to a HuggingFace :code:`ImageProcessor`. 
""" -VideoItem: TypeAlias = Union[ - list[Image], - np.ndarray, - torch.Tensor, - list[np.ndarray], - list[torch.Tensor], -] +HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor, + list[np.ndarray], list[torch.Tensor]] """ A :class:`transformers.image_utils.VideoInput` representing a single video item, which can be passed to a HuggingFace :code:`VideoProcessor`. """ -AudioItem: TypeAlias = Union[ - np.ndarray, - list[float], - # `(audio, sampling_rate)`: If the audio's sampling rate is different - # from that expected by the model, we need to resample it. - tuple[np.ndarray, float], -] +HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor] """ Represents a single audio item, which can be passed to a HuggingFace :code:`AudioProcessor`. """ -# yapf: enable -MultiModalData: TypeAlias = Union[_T, list[_T]] +ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor] +""" +A :class:`transformers.image_utils.ImageInput` representing a single image +item, which can be passed to a HuggingFace :code:`ImageProcessor`. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as image embeddings; +these are directly passed to the model without HF processing. +""" + +VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor] +""" +A :class:`transformers.image_utils.VideoInput` representing a single video +item, which can be passed to a HuggingFace :code:`VideoProcessor`. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as video embeddings; +these are directly passed to the model without HF processing. +""" + +AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], + torch.Tensor] +""" +Represents a single audio +item, which can be passed to a HuggingFace :code:`AudioProcessor`. + +Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate +is different from that expected by the model; +these are resampled to the model's sampling rate before being processed by HF. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as audio embeddings; +these are directly passed to the model without HF processing. +""" + +ModalityData: TypeAlias = Union[_T, list[_T]] """ Either a single data item, or a list of data items. @@ -61,17 +82,17 @@ class MultiModalDataBuiltins(TypedDict, total=False): """Type annotations for modality types predefined by vLLM.""" - image: MultiModalData[ImageItem] + image: ModalityData[ImageItem] """The input image(s).""" - video: MultiModalData[VideoItem] + video: ModalityData[VideoItem] """The input video(s).""" - audio: MultiModalData[AudioItem] + audio: ModalityData[AudioItem] """The input audio(s).""" -MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]] +MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]] """ A dictionary containing an entry for each modality type to input. @@ -83,123 +104,6 @@ class MultiModalDataBuiltins(TypedDict, total=False): """ -class ImageSize(NamedTuple): - width: int - height: int - - -class MultiModalDataItems(UserDict[str, list[Any]]): - """ - As :class:`MultiModalDataDict`, but normalized such that each entry - corresponds to a list. - """ - - @staticmethod - def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": - """ - Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. 
- """ - multi_data = MultiModalDataItems() - - for k, v in data.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - # yapf: disable - if k == "video": - # Special case since even a single item can be a list - multi_data[k] = ( # type: ignore[index] - v if ( - isinstance(v, torch.Tensor) - or is_list_of(v, list) - or isinstance(v[0], (np.ndarray, torch.Tensor)) - and v[0].ndim == 4 - ) else [v] - ) - elif k in ("image", "audio"): - multi_data[k] = ( # type: ignore[index] - v if isinstance(v, (torch.Tensor, list)) else [v] - ) - else: - multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] - # yapf: enable - - return multi_data - - # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to - # `self.images` doesn't update this dictionary, which may be confusing - # We annotate the getter methods as `Sequence` to prevent others from - # trying to update the list in this way - @property - def images(self) -> Sequence[ImageItem]: - return self.get("image", []) - - @property - def videos(self) -> Sequence[VideoItem]: - return self.get("video", []) - - @property - def audios(self) -> Sequence[AudioItem]: - return self.get("audio", []) - - def get_item_counts(self) -> Mapping[str, int]: - return {m: len(items) for m, items in self.items()} - - def has_embedding_inputs(self) -> bool: - return any( - any(isinstance(item, torch.Tensor) for item in items) - for items in self.values()) - - def get_image_size(self, item_idx: int) -> ImageSize: - image = self.images[item_idx] - - if isinstance(image, Image): - return ImageSize(*image.size) - if isinstance(image, (np.ndarray, torch.Tensor)): - _, h, w = image.shape - return ImageSize(w, h) - - assert_never(image) - - def get_audio_with_sr( - self, - item_idx: int, - *, - default_sr: float, - ) -> tuple[np.ndarray, float]: - audio = self.audios[item_idx] - - if isinstance(audio, tuple): - return audio - if isinstance(audio, list): - return np.array(audio), default_sr - if isinstance(audio, np.ndarray): - return audio, default_sr - - assert_never(audio) - - def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None: - """ - If :code:`drop_sr=True`, the audio items in this dictionary are updated - to be NumPy arrays which implicitly means that their sampling rate is - the same as the model's expected sampling rate; otherwise, they remain - as :code:`(audio, new_sr)` tuples. - """ - # Avoid circular import - from .audio import resample_audio - - if not self.audios: - return - - new_audios = [] - for item_idx in range(len(self.audios)): - audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr) - audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr) - - new_audios.append(audio if drop_sr else (audio, new_sr)) - - self["audio"] = new_audios - - class PlaceholderRange(TypedDict): """ Placeholder location information for multi-modal data. @@ -436,7 +340,7 @@ def from_items_by_key( ) -> "MultiModalKwargs": data = { key: items[0].field.reduce(items).data - for key, items in items_by_key.items() + for key, items in items_by_key.items() if len(items) > 0 } return MultiModalKwargs(data, @@ -567,6 +471,11 @@ def get_items_by_modality( Get the keyword arguments corresponding to an item identified by its modality and index. """ + if modality not in self._keys_by_modality: + available_modalities = set(self._keys_by_modality.keys()) + raise KeyError(f"Modality {modality!r} not found. 
" + f"Available modalities: {available_modalities}") + keys_to_gather = self._keys_by_modality[modality] return { diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py new file mode 100644 index 0000000000000..17a795247372e --- /dev/null +++ b/vllm/multimodal/parse.py @@ -0,0 +1,344 @@ +from abc import ABC, abstractmethod +from collections import UserDict +from collections.abc import Callable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar + +import numpy as np +import torch +from PIL.Image import Image +from typing_extensions import TypeAlias, TypeGuard, assert_never + +from vllm.utils import is_list_of + +from .audio import resample_audio +from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, + ImageItem, ModalityData, MultiModalDataDict, + NestedTensors, VideoItem) + +_T = TypeVar("_T") +_I = TypeVar("_I") + + +class ModalityDataItems(ABC, Generic[_T, _I]): + + def __init__(self, data: _T) -> None: + super().__init__() + + self.data = data + + def __len__(self) -> int: + return self.get_count() + + def __getitem__(self, index: int) -> _I: + return self.get(index) + + if TYPE_CHECKING: + # Auto-generated + def __iter__(self) -> Iterator[_I]: + ... + + @abstractmethod + def get_count(self) -> int: + """Get the number of data items.""" + raise NotImplementedError + + @abstractmethod + def get(self, index: int) -> _I: + """Get a data item by its index.""" + raise NotImplementedError + + def get_all(self) -> list[_I]: + """Get all data items.""" + return [self.get(idx) for idx in range(self.get_count())] + + @abstractmethod + def get_processor_data(self) -> Mapping[str, object]: + """Get the data to pass to the HF processor.""" + raise NotImplementedError + + @abstractmethod + def get_passthrough_data(self) -> Mapping[str, object]: + """Get the data to pass directly to the model.""" + raise NotImplementedError + + +class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): + + def __init__(self, data: Sequence[_T], modality: str) -> None: + super().__init__(data) + + self.modality = modality + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r})") + + def get_count(self) -> int: + return len(self.data) + + def get(self, index: int) -> _T: + return self.data[index] + + def get_processor_data(self) -> Mapping[str, object]: + return {f"{self.modality}s": self.data} + + def get_passthrough_data(self) -> Mapping[str, object]: + return {} + + +class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]): + + def __init__(self, data: NestedTensors, modality: str) -> None: + super().__init__(data) + + self.modality = modality + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r})") + + def get_count(self) -> int: + return len(self.data) + + def get(self, index: int) -> object: + return self.data[index] + + def get_processor_data(self) -> Mapping[str, object]: + return {} + + def get_passthrough_data(self) -> Mapping[str, object]: + return {f"{self.modality}_embeds": self.data} + + +class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): + + def __init__(self, data: Sequence[HfAudioItem]) -> None: + super().__init__(data, "audio") + + +class AudioEmbeddingItems(EmbeddingItems): + + def __init__(self, data: NestedTensors) -> None: + super().__init__(data, "audio") + + +class ImageSize(NamedTuple): + width: int + height: int + + +class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): + + def __init__(self, data: 
Sequence[HfImageItem]) -> None: + super().__init__(data, "image") + + def get_image_size(self, item_idx: int) -> ImageSize: + image = self.get(item_idx) + + if isinstance(image, Image): + return ImageSize(*image.size) + if isinstance(image, (np.ndarray, torch.Tensor)): + _, h, w = image.shape + return ImageSize(w, h) + + assert_never(image) + + +class ImageEmbeddingItems(EmbeddingItems): + + def __init__(self, data: NestedTensors) -> None: + super().__init__(data, "image") + + +class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): + + def __init__(self, data: Sequence[HfVideoItem]) -> None: + super().__init__(data, "video") + + +class VideoEmbeddingItems(EmbeddingItems): + + def __init__(self, data: NestedTensors) -> None: + super().__init__(data, "video") + + +_D = TypeVar("_D", bound=ModalityDataItems[Any, Any]) + + +class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): + """ + As :class:`MultiModalDataDict`, but normalized such that each entry + corresponds to a list. + """ + + def get_count(self, modality: str, *, strict: bool = True) -> int: + """ + Get the number of data items belonging to a modality. + + If `strict=False`, return `0` instead of raising :exc:`KeyError` + even if the modality is not found. + """ + if modality not in self: + if strict: + available_modalities = set(self.keys()) + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}") + + return 0 + + return self[modality].get_count() + + def get_all_counts(self) -> Mapping[str, int]: + """Get the number of items belonging to each modality.""" + return {m: items.get_count() for m, items in self.items()} + + def get_items( + self, + modality: str, + typ: type[_D], + ) -> _D: + """ + Get the data items belonging to a modality, + requiring that they belong to a certain type. + """ + if modality not in self: + available_modalities = set(self.keys()) + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}") + + items = self[modality] + if not isinstance(items, typ): + raise TypeError(f"Invalid type of data items for {modality=}. " + f"Expected type: {typ}, but " + f"found type: {type(items)}") + + return items + + +ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], + ModalityDataItems[Any, Any]] + + +class MultiModalDataParser: + """ + Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`. 
+ """ + + def __init__(self, *, target_sr: Optional[float] = None) -> None: + super().__init__() + + self.target_sr = target_sr + + def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]: + if isinstance(data, torch.Tensor): + return data.ndim == 3 + if is_list_of(data, torch.Tensor): + return len(data) == 0 or data[0].ndim == 2 + + return False + + def _get_audio_with_sr( + self, + audio: AudioItem, + ) -> tuple[np.ndarray, Optional[float]]: + if isinstance(audio, tuple): + return audio + if isinstance(audio, list): + return np.array(audio), None + if isinstance(audio, np.ndarray): + return audio, None + if isinstance(audio, torch.Tensor): + return audio.numpy(), None + + assert_never(audio) + + def _parse_audio_data( + self, + data: ModalityData[AudioItem], + ) -> ModalityDataItems[Any, Any]: + if self._is_embeddings(data): + return AudioEmbeddingItems(data) + + if (is_list_of(data, float) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 1 + or isinstance(data, tuple)): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + new_audios = list[np.ndarray]() + for data_item in data_items: + audio, orig_sr = self._get_audio_with_sr(data_item) + if orig_sr is None: + new_audio = audio + else: + target_sr = self.target_sr + if target_sr is None: + raise RuntimeError( + "Audio resampling is not supported when " + "`target_sr` is not provided") + + new_audio = resample_audio(audio, + orig_sr=orig_sr, + target_sr=target_sr) + + new_audios.append(new_audio) + + return AudioProcessorItems(new_audios) + + def _parse_image_data( + self, + data: ModalityData[ImageItem], + ) -> ModalityDataItems[Any, Any]: + if self._is_embeddings(data): + return ImageEmbeddingItems(data) + + if (isinstance(data, Image) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 3): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + return ImageProcessorItems(data_items) + + def _parse_video_data( + self, + data: ModalityData[VideoItem], + ) -> ModalityDataItems[Any, Any]: + if self._is_embeddings(data): + return VideoEmbeddingItems(data) + + if (is_list_of(data, Image) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 4): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + return VideoProcessorItems(data_items) + + def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: + return { + "audio": self._parse_audio_data, + "image": self._parse_image_data, + "video": self._parse_video_data, + } + + def parse_mm_data(self, + mm_data: MultiModalDataDict) -> MultiModalDataItems: + subparsers = self._get_subparsers() + + mm_items = MultiModalDataItems() + for k, v in mm_data.items(): + if k not in subparsers: + raise ValueError(f"Unsupported modality: {k}") + + mm_items[k] = subparsers[k](v) + + return mm_items diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 3ece0762e3228..180489166b407 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -15,11 +15,12 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of +from vllm.utils import 
LRUCache, flatten_2d_lists, full_groupby -from .inputs import (MultiModalDataDict, MultiModalDataItems, - MultiModalFieldConfig, MultiModalFieldItem, - MultiModalInputsV2, MultiModalKwargs, PlaceholderRange) +from .inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs, + PlaceholderRange) +from .parse import MultiModalDataItems, MultiModalDataParser logger = init_logger(__name__) @@ -621,6 +622,16 @@ def __call__( ) -> MultiModalInputsV2: return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + def _get_data_parser(self) -> MultiModalDataParser: + """ + Construct a data parser to preprocess multi-modal data items + before passing them to :meth:`_get_hf_mm_data`. + + You can support additional modalities by creating a subclass + of :class:`MultiModalDataParser` that has additional subparsers. + """ + return MultiModalDataParser() + def _get_hf_processor(self) -> ProcessorMixin: """ Subclasses can add keyword arguments to this method to accept @@ -631,11 +642,16 @@ def _get_hf_processor(self) -> ProcessorMixin: def _get_tokenizer(self) -> AnyTokenizer: return self.ctx.tokenizer - def _get_mm_items( + def _to_mm_items( self, mm_data: MultiModalDataDict, ) -> MultiModalDataItems: - return MultiModalDataItems.from_dict(mm_data) + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` + before passing them to :meth:`_get_hf_mm_data`. + """ + parser = self._get_data_parser() + return parser.parse_mm_data(mm_data) @abstractmethod def _get_mm_fields_config( @@ -680,22 +696,9 @@ def _get_hf_mm_data( processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() - for k, v in mm_items.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - if k in ("image", "video", "audio"): - if isinstance(v, torch.Tensor) and v.ndim == 3: - # Pass through embedding inputs (single) - passthrough_data[f"{k}_embeds"] = [v] - elif (is_list_of(v, torch.Tensor) and len(v) > 0 - and v[0].ndim == 2): - # Pass through embedding inputs (multi) - passthrough_data[f"{k}_embeds"] = v - elif len(v) > 0: - # Map keys to plural form, e.g.: image -> images - processor_data[f"{k}s"] = v - else: - processor_data[k] = v + for items in mm_items.values(): + processor_data.update(items.get_processor_data()) + passthrough_data.update(items.get_passthrough_data()) return processor_data, passthrough_data @@ -756,7 +759,7 @@ def _apply_hf_processor_missing( cached items; instead, we rely on our own prompt replacement logic for the full text. 
""" - mm_missing_counts = mm_missing_data_items.get_item_counts() + mm_missing_counts = mm_missing_data_items.get_all_counts() prompt_ids, _ = self._apply_hf_processor( prompt_text=prompt_text, @@ -789,7 +792,8 @@ def _cached_apply_hf_processor( cache = self.cache model_id = self.ctx.model_config.model - if cache is None or mm_data_items.has_embedding_inputs(): + _, passthrough_data = self._get_hf_mm_data(mm_data_items) + if cache is None or passthrough_data: return self._apply_hf_processor( prompt_text=prompt_text, mm_items=mm_data_items, @@ -812,7 +816,7 @@ def _cached_apply_hf_processor( modality: [mm_data_items[modality][idx] for idx in idxs] for modality, idxs in mm_missing_idxs.items() } - mm_missing_data_items = self._get_mm_items(mm_missing_data) + mm_missing_data_items = self._to_mm_items(mm_missing_data) prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( prompt_text=prompt_text, @@ -852,7 +856,7 @@ def _cached_apply_hf_processor( mm_merged_field_items[modality] = merged_modal_items_lst if self.enable_sanity_checks: - mm_missing_counts = mm_missing_data_items.get_item_counts() + mm_missing_counts = mm_missing_data_items.get_all_counts() assert all( item_count == mm_missing_counts[modality] for modality, item_count in mm_missing_next_idx.items()), dict( @@ -865,7 +869,7 @@ def _cached_apply_hf_processor( ) if self.enable_sanity_checks: - mm_item_counts = mm_data_items.get_item_counts() + mm_item_counts = mm_data_items.get_all_counts() for modality, item_count in mm_item_counts.items(): for item_idx in range(item_count): @@ -958,7 +962,7 @@ def apply( 3. Extract information about the placeholder tokens from the processed token IDs. """ - mm_items = self._get_mm_items(mm_data) + mm_items = self._to_mm_items(mm_data) prompt_ids, mm_kwargs = self._cached_apply_hf_processor( prompt_text, @@ -975,7 +979,7 @@ def apply( # If HF processor already inserts placeholder tokens, # there is no need for us to insert them - mm_item_counts = mm_items.get_item_counts() + mm_item_counts = mm_items.get_all_counts() all_placeholders = self._find_placeholders(prompt_repls, prompt_ids, mm_item_counts) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index b7d43c830cc46..1ad1f5abc27a2 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -15,7 +15,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import PlaceholderModule, is_list_of -from .base import MediaIO, MultiModalData +from .base import MediaIO, ModalityData from .image import ImageMediaIO, ImagePlugin from .inputs import MultiModalKwargs, VideoItem @@ -54,7 +54,7 @@ def _get_hf_video_processor( def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[VideoItem], + data: ModalityData[VideoItem], **mm_processor_kwargs, ) -> MultiModalKwargs: model_config = ctx.model_config