Skip to content

Commit

Permalink
[Core][VLM] Stack multimodal tensors to represent multiple images wit…
Browse files Browse the repository at this point in the history
…hin each prompt (vllm-project#7902)
  • Loading branch information
petersalas authored Aug 28, 2024
1 parent 9c71c97 commit fab5f53
Show file tree
Hide file tree
Showing 15 changed files with 214 additions and 60 deletions.
2 changes: 0 additions & 2 deletions docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ Base Classes

.. autodata:: vllm.multimodal.NestedTensors

.. autodata:: vllm.multimodal.BatchedTensors

.. autodata:: vllm.multimodal.BatchedTensorInputs

.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
Expand Down
83 changes: 83 additions & 0 deletions tests/multimodal/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch

from vllm.multimodal.base import MultiModalInputs, NestedTensors


def assert_nested_tensors_equal(expected: NestedTensors,
actual: NestedTensors):
assert type(expected) == type(actual)
if isinstance(expected, torch.Tensor):
assert torch.equal(expected, actual)
else:
for expected_item, actual_item in zip(expected, actual):
assert_nested_tensors_equal(expected_item, actual_item)


def assert_multimodal_inputs_equal(expected: MultiModalInputs,
actual: MultiModalInputs):
assert set(expected.keys()) == set(actual.keys())
for key in expected:
assert_nested_tensors_equal(expected[key], actual[key])


def test_multimodal_input_batch_single_tensor():
t = torch.rand([1, 2])
result = MultiModalInputs.batch([{"image": t}])
assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})


def test_multimodal_input_batch_multiple_tensors():
a = torch.rand([1, 1, 2])
b = torch.rand([1, 1, 2])
c = torch.rand([1, 1, 2])
result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})


def test_multimodal_input_batch_multiple_heterogeneous_tensors():
a = torch.rand([1, 2, 2])
b = torch.rand([1, 3, 2])
c = torch.rand([1, 4, 2])
result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
assert_multimodal_inputs_equal(result, {"image": [a, b, c]})


def test_multimodal_input_batch_nested_tensors():
a = torch.rand([2, 3])
b = torch.rand([2, 3])
c = torch.rand([2, 3])
result = MultiModalInputs.batch([{
"image": [a]
}, {
"image": [b]
}, {
"image": [c]
}])
assert_multimodal_inputs_equal(result, {
"image":
torch.stack([a.unsqueeze(0),
b.unsqueeze(0),
c.unsqueeze(0)])
})


def test_multimodal_input_batch_heterogeneous_lists():
a = torch.rand([1, 2, 3])
b = torch.rand([1, 2, 3])
c = torch.rand([1, 2, 3])
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
assert_multimodal_inputs_equal(
result,
{"image": [torch.stack([a, b]), c.unsqueeze(0)]})


def test_multimodal_input_batch_multiple_batchable_lists():
a = torch.rand([1, 2, 3])
b = torch.rand([1, 2, 3])
c = torch.rand([1, 2, 3])
d = torch.rand([1, 2, 3])
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
assert_multimodal_inputs_equal(
result,
{"image": torch.stack([torch.stack([a, b]),
torch.stack([c, d])])})
7 changes: 7 additions & 0 deletions vllm/model_executor/models/blip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,9 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)

return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
Expand All @@ -564,6 +567,10 @@ def _parse_and_validate_image_input(
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)

return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,9 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)

return ChameleonImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def _parse_and_validate_image_input(
image_patches = kwargs.pop("image_patches", None)

if isinstance(image_patches, torch.Tensor):
# Remove the N dimension until multiple images are supported.
image_patches = image_patches.squeeze(1)

expected_feature_size = self.image_feature_size
if image_patches.size(-1) != expected_feature_size:
raise ValueError(
Expand Down
9 changes: 9 additions & 0 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
min_num,
max_num,
use_thumbnail=use_thumbnail)
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
Expand Down Expand Up @@ -410,6 +412,10 @@ def _parse_and_validate_image_input(
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

# Flatten the B and N dimensions
image_embeds = image_embeds.flatten(0, 2)

return InternVLImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
Expand All @@ -422,6 +428,9 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

# Flatten the B and N dimensions
pixel_values = pixel_values.flatten(0, 2)

return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
Expand Down
8 changes: 8 additions & 0 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def _parse_and_validate_image_input(
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
Expand All @@ -241,6 +245,10 @@ def _parse_and_validate_image_input(
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)

return LlavaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
Expand Down
11 changes: 11 additions & 0 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,14 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")

# Remove the N dimension until multiple images are supported.
if isinstance(pixel_values, torch.Tensor):
pixel_values = pixel_values.squeeze(1)
else:
pixel_values = [t.squeeze(0) for t in pixel_values]

image_sizes = image_sizes.squeeze(1)

return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
Expand All @@ -372,6 +380,9 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")

# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)

return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
Expand Down
11 changes: 8 additions & 3 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,14 @@ def _parse_and_validate_inputs(

pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for b in range(len(pixel_values)):
pixel_values_flat += pixel_values[b]
tgt_sizes_flat += tgt_sizes[b]
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
if len(pixel_b) != len(tgt_b):
raise ValueError("Inconsistent N lengths, found: "
f"{len(pixel_b)} vs {len(tgt_b)}")

for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += pixel_n
tgt_sizes_flat += tgt_n

# NOTE: Input IDs does not contain image tokens during memory profiling,
# so we allow it to be empty
Expand Down
8 changes: 8 additions & 0 deletions vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ def _parse_and_validate_image_input(
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)

return PaliGemmaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
Expand All @@ -194,6 +198,10 @@ def _parse_and_validate_image_input(
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)

return PaliGemmaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
Expand Down
8 changes: 8 additions & 0 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,14 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")

# Merge the B and N dimensions.
if isinstance(pixel_values, torch.Tensor):
pixel_values = pixel_values.flatten(0, 1)
else:
pixel_values = torch.cat(pixel_values)

image_sizes = image_sizes.flatten(0, 1)

return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
Expand Down
9 changes: 9 additions & 0 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ def _parse_and_validate_audio_input(
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")

# Remove the N dimension until multiple audios are supported.
if isinstance(audio_features, torch.Tensor):
audio_features = audio_features.squeeze(1)
else:
audio_features = [t.squeeze(0) for t in audio_features]

return UltravoxAudioFeatureInputs(type="audio_features",
data=audio_features)

Expand All @@ -341,6 +347,9 @@ def _parse_and_validate_audio_input(
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")

# Remove the N dimension until multiple audios are supported.
audio_embeds = audio_embeds.squeeze(1)

return UltravoxAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)

Expand Down
60 changes: 37 additions & 23 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, Iterable, List, Optional, Protocol, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.func import functional_call
Expand All @@ -10,7 +11,7 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal import BatchedTensors
from vllm.multimodal.base import NestedTensors
from vllm.utils import is_pin_memory_available


Expand Down Expand Up @@ -54,9 +55,34 @@ def init_vllm_registered_model(
)


def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
"""
Recursively concatenates NestedTensors along any heterogeneously sized
dimensions.
"""

if isinstance(embeddings, torch.Tensor):
return embeddings

return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""

if isinstance(embeddings, torch.Tensor):
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

return " + ".join(
_embedding_count_expression(inner) for inner in embeddings)


def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: BatchedTensors,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
Expand All @@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum()

if isinstance(multimodal_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
total_tokens = batch_size * batch_tokens
if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}"
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders")

inputs_embeds[mask] = multimodal_embeddings.view(
total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in multimodal_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders")

inputs_embeds[mask] = torch.cat(multimodal_embeddings)
flattened = _flatten_embeddings(multimodal_embeddings)
*dims, embed_dim = flattened.shape
num_multimodal_embeddings = np.prod(dims)
if num_multimodal_embeddings != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {num_multimodal_embeddings} "
f"multimodal tokens to {num_expected_tokens} placeholders")

inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim)
return inputs_embeds


Expand Down
3 changes: 1 addition & 2 deletions vllm/multimodal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins,
from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
NestedTensors)
from .registry import MultiModalRegistry
Expand All @@ -14,7 +14,6 @@

__all__ = [
"BatchedTensorInputs",
"BatchedTensors",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalInputs",
Expand Down
Loading

0 comments on commit fab5f53

Please sign in to comment.