[V1] Refactor model executable interface for multimodal models (vllm-project#10570)

Signed-off-by: Roger Wang <[email protected]>
ywang96 authored Nov 26, 2024
1 parent 7576cd3 commit 2f0a0a1
Showing 18 changed files with 581 additions and 306 deletions.
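
The refactor gives every multimodal model the same two-step embedding interface: get_multimodal_embeddings() encodes the multimodal kwargs, get_input_embeddings() merges the result into the text embeddings, and forward() accepts a precomputed inputs_embeds. The diffs below apply this to BLIP-2, Chameleon, ChatGLM (GLM-4V), Fuyu, and the SupportsMultiModal protocol in interfaces.py; the remaining changed files are not shown here. For orientation, here is a condensed, runnable toy of that shape. Only the method names and signatures come from this commit; ToyMultiModalLM, IMAGE_TOKEN_ID, embed_tokens, and vision_proj are invented for illustration, and NestedTensors is a simplified stand-in, not the real vLLM type.

from typing import List, Optional, Union

import torch
import torch.nn as nn

# Simplified stand-in for vllm.multimodal.inputs.NestedTensors.
NestedTensors = Union[torch.Tensor, List[torch.Tensor]]
IMAGE_TOKEN_ID = 3  # hypothetical placeholder token id


class ToyMultiModalLM(nn.Module):
    """Toy model mirroring the refactored interface; not real vLLM code."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.vision_proj = nn.Linear(4, hidden_size)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        # Text-only requests carry no multimodal kwargs and return None.
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is None:
            return None
        return self.vision_proj(pixel_values)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None:
            # Splice the vision embeddings into the placeholder positions.
            mask = input_ids == IMAGE_TOKEN_ID
            inputs_embeds = inputs_embeds.clone()
            inputs_embeds[mask] = multimodal_embeddings.to(inputs_embeds.dtype)
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        # v1 passes precomputed inputs_embeds; the v0 fallback builds them here.
        if inputs_embeds is None:
            mm_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, mm_embeds)
        return inputs_embeds  # stand-in for running the decoder stack

Calling ToyMultiModalLM()(torch.tensor([[5, IMAGE_TOKEN_ID, 7]]), pixel_values=torch.randn(1, 4)) exercises the v0-style fallback end to end.
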
61 changes: 37 additions & 24 deletions vllm/model_executor/models/blip2.py
@@ -16,6 +16,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData

@@ -609,13 +610,33 @@ def _process_image_input(self,

return self.language_projection(query_output)

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
BLIP2_IMAGE_TOKEN_ID)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
"""Run forward pass for BLIP-2.
@@ -648,32 +669,24 @@ def forward(
See also:
:class:`Blip2ImageInputs`
"""

if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)

input_ids = None
else:
inputs_embeds = None

hidden_states = self.language_model.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)

# NOTE: In v1, inputs_embeds is always generated at the model runner;
# this condition is only for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None

hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)

return hidden_states
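
merge_multimodal_embeddings (imported from .utils in these models) is the helper that places vision embeddings at the image-placeholder positions, here identified by BLIP2_IMAGE_TOKEN_ID. The following is a simplified approximation of what it does, for readers who do not want to open vllm/model_executor/models/utils.py; the real helper additionally supports nested/list inputs and performs its own validation, so treat this as a sketch rather than the actual implementation.

import torch


def merge_multimodal_embeddings_sketch(
    input_ids: torch.Tensor,              # (num_tokens,)
    inputs_embeds: torch.Tensor,          # (num_tokens, hidden_size)
    multimodal_embeddings: torch.Tensor,  # (num_placeholders, hidden_size)
    placeholder_token_id: int,
) -> torch.Tensor:
    # Every position holding the placeholder token id receives one
    # multimodal embedding, in order.
    mask = input_ids == placeholder_token_id
    if int(mask.sum()) != multimodal_embeddings.shape[0]:
        raise ValueError("placeholder count does not match the number of "
                         "multimodal embeddings")
    merged = inputs_embeds.clone()
    merged[mask] = multimodal_embeddings.to(inputs_embeds.dtype)
    return merged
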

58 changes: 41 additions & 17 deletions vllm/model_executor/models/chameleon.py
@@ -29,6 +29,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
@@ -38,7 +39,7 @@
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
maybe_prefix, merge_multimodal_embeddings)

# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
@@ -987,34 +988,57 @@ def _parse_and_validate_image_input(
data=self._validate_pixel_values(pixel_values),
)

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(image_input["data"].to(
self.config.torch_dtype))
vision_embeddings = self.model.get_input_embeddings(image_tokens)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:

inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.model.vocabulary_mapping.image_token_id)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:

if intermediate_tensors is not None:
inputs_embeds = None

# NOTE: In v1, inputs_embeds is always generated at the model runner;
# this condition is only for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(
image_input["data"].to(self.config.torch_dtype))
image_token_id = self.model.vocabulary_mapping.image_token_id
special_image_mask = input_ids == image_token_id
image_tokens = image_tokens.to(input_ids.device,
input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask,
image_tokens)

hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)

hidden_states = self.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states

def compute_logits(
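
The NOTE repeated in each forward() ("In v1, inputs_embeds is always generated at the model runner") describes a split in calling conventions rather than anything model-specific. Below is a rough, hypothetical illustration of the two call paths; it is not vLLM's actual runner code, and it assumes a model object exposing the three methods from this commit (the toy model sketched near the top of this page qualifies).

import torch


def run_v1_style(model, input_ids: torch.Tensor, **mm_kwargs) -> torch.Tensor:
    # The v1 model runner materializes the embeddings itself, then calls
    # forward() with inputs_embeds and without token ids.
    mm_embeds = model.get_multimodal_embeddings(**mm_kwargs)
    inputs_embeds = model.get_input_embeddings(input_ids, mm_embeds)
    return model(input_ids=None, inputs_embeds=inputs_embeds)


def run_v0_style(model, input_ids: torch.Tensor, **mm_kwargs) -> torch.Tensor:
    # The v0 path leaves inputs_embeds as None, so forward() falls into the
    # "elif inputs_embeds is None" branch shown in the diffs above and builds
    # the embeddings internally from the multimodal kwargs.
    return model(input_ids=input_ids, inputs_embeds=None, **mm_kwargs)
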
54 changes: 35 additions & 19 deletions vllm/model_executor/models/chatglm.py
@@ -33,7 +33,8 @@
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
@@ -545,33 +546,48 @@ def _parse_and_validate_image_input(
""")
return GLMImagePixelInputs(pixel_values=pixel_values)

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input["pixel_values"] is None:
return None
pixel_values = image_input["pixel_values"].to(
dtype=self.config.torch_dtype)
vision_embeddings = self.vision(pixel_values)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_glm_vision_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
vision_embeddings=multimodal_embeddings,
boi_token_id=self.config.boi_token_id,
eoi_token_id=self.config.eoi_token_id)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> torch.Tensor:
if intermediate_tensors is None:
inputs_embeds = self.embedding(input_ids)
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input["pixel_values"] is not None:
pixel_values = image_input["pixel_values"].to(
dtype=inputs_embeds.dtype)
image_embeds = self.vision(pixel_values)

boi_token_id = self.config.boi_token_id
eoi_token_id = self.config.eoi_token_id

inputs_embeds = merge_glm_vision_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
vision_embeddings=image_embeds,
boi_token_id=boi_token_id,
eoi_token_id=eoi_token_id)

# NOTE: In v1, inputs_embeds is always generated at the model runner;
# this condition is only for v0 compatibility.
if intermediate_tensors is None and inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
else:
inputs_embeds = intermediate_tensors["hidden_states"]

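
GLM-4V merges differently from the placeholder-token models above: merge_glm_vision_embeddings locates the image span via the begin-of-image and end-of-image token ids instead of a single repeated placeholder id. The following is a rough sketch of that idea only, written under the assumption that the vision embeddings fill the span delimited by the BOI/EOI markers; consult merge_glm_vision_embeddings in chatglm.py for the exact treatment of the marker positions and of multiple images.

import torch


def merge_by_boi_eoi_sketch(
    input_ids: torch.Tensor,          # (num_tokens,)
    inputs_embeds: torch.Tensor,      # (num_tokens, hidden_size)
    vision_embeddings: torch.Tensor,  # (span_length, hidden_size)
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    # Find the (single, for simplicity) image span and overwrite it.
    boi_pos = int((input_ids == boi_token_id).nonzero()[0])
    eoi_pos = int((input_ids == eoi_token_id).nonzero()[0])
    merged = inputs_embeds.clone()
    # This sketch overwrites the positions strictly between the markers;
    # whether the markers themselves are replaced is left to the real helper.
    merged[boi_pos + 1:eoi_pos] = vision_embeddings.to(merged.dtype)
    return merged
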
43 changes: 29 additions & 14 deletions vllm/model_executor/models/fuyu.py
@@ -35,6 +35,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@@ -302,31 +303,45 @@ def _process_image_input(
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)

else:
inputs_embeds = None

# NOTE: In v1, inputs_embeds is always generated at the model runner;
# this condition is only for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None

hidden_states = self.language_model(
input_ids=input_ids,
36 changes: 35 additions & 1 deletion vllm/model_executor/models/interfaces.py
@@ -2,18 +2,22 @@
Protocol, Type, Union, overload, runtime_checkable)

import torch
from typing_extensions import TypeIs
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils import supports_kw

from .interfaces_base import is_embedding_model

if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.multimodal.inputs import NestedTensors # noqa: F401
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)

T = TypeVar("T", default="NestedTensors")


@runtime_checkable
class SupportsMultiModal(Protocol):
@@ -28,6 +32,36 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""

def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
"""
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
"""
...

# Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated
@overload
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[T] = None,
attn_metadata: Optional["AttentionMetadata"] = None,
) -> torch.Tensor:
...

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[T] = None,
) -> torch.Tensor:
"""
Returns the input embeddings obtained by merging the text embeddings
from input_ids with the multimodal embeddings generated from the
multimodal kwargs.
"""
...
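
Because SupportsMultiModal is a runtime_checkable Protocol, conformance is structural: any class providing these methods passes an isinstance() check, which only verifies that the methods exist, not their types. A minimal, self-contained illustration using a stand-in protocol (SupportsMultiModalSketch and TinyModel are invented for this example, not imports of the real classes):

from typing import Optional, Protocol, runtime_checkable

import torch


@runtime_checkable
class SupportsMultiModalSketch(Protocol):
    """Stand-in protocol mirroring the two new methods above."""

    def get_multimodal_embeddings(self, **kwargs) -> Optional[torch.Tensor]:
        ...

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        ...


class TinyModel:

    def get_multimodal_embeddings(self, **kwargs) -> Optional[torch.Tensor]:
        return kwargs.get("pixel_values")

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return torch.zeros(input_ids.shape[0], 8)


# Structural check only: isinstance verifies the methods exist, not their signatures.
assert isinstance(TinyModel(), SupportsMultiModalSketch)
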


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
