Skip to content

Commit

Permalink
[Model] Add Support for Multimodal Granite Models (#10291)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
  • Loading branch information
alex-jw-brooks and DarkLight1337 authored Nov 21, 2024
1 parent f0e0238 commit 1cfde82
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 35 deletions.
47 changes: 35 additions & 12 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
repeat_and_pad_placeholder_tokens,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData

from .utils import get_vit_attn_backend
Expand Down Expand Up @@ -389,12 +390,20 @@ def __init__(
for layer_idx in range(num_hidden_layers)
])

def forward(self, inputs_embeds: torch.Tensor):

def forward(
self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
) -> Union[torch.Tensor, list[torch.Tensor]]:
hidden_states_pool = []
hidden_states = inputs_embeds

for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)

if return_all_hidden_states:
hidden_states_pool.append(hidden_states)
# If we have multiple feature sample layers, we return all hidden
# states in order and grab the ones we need by index.
if return_all_hidden_states:
return hidden_states_pool
return hidden_states


Expand All @@ -419,6 +428,7 @@ def __init__(
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
# the original transformers code and name of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
Expand Down Expand Up @@ -446,16 +456,26 @@ def __init__(
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:

hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states)

if self.post_layernorm is None:
return hidden_states
return_all_hidden_states = feature_sample_layers is not None

# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states)

# Handle post-norm (if applicable) and stacks feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)

return self.post_layernorm(hidden_states)
return encoder_outputs


class CLIPVisionModel(nn.Module):
Expand All @@ -478,11 +498,14 @@ def __init__(
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
)
prefix=f"{prefix}.vision_model")

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values)
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
return self.vision_model(pixel_values, feature_sample_layers)

@property
def device(self):
Expand Down
45 changes: 37 additions & 8 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,41 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):

class LlavaLikeConfig(Protocol):
vision_config: PretrainedConfig
vision_feature_layer: int
vision_feature_layer: Union[int, List[int]]


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
"""Determine the number of hidden layers to initialize up to in the
visual encoder.
Args:
hf_config: Model config with vision feature layer(s).
"""
feature_layers = hf_config.vision_feature_layer
num_hidden_layers = hf_config.vision_config.num_hidden_layers
# If we have one feature layer, initialize up to that layer
if isinstance(feature_layers, int):
return _get_layer_index(feature_layers, num_hidden_layers)
# If we have multiple feature layers, initialize up to the deepest one
elif isinstance(feature_layers, (list, tuple)):
return max(
_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
" is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
"""Given an signed vision feature layer, get the number of hidden layers
needed to leverage it.
Args:
feature_layer_index: Index of a required layer in the visual encoder.
num_hidden_layers: The total number of hidden layers in the visual
encoder.
"""
if feature_layer_index < 0:
return num_hidden_layers + feature_layer_index + 1
return feature_layer_index + 1


def init_vision_tower_for_llava(
Expand All @@ -216,13 +250,8 @@ def init_vision_tower_for_llava(
):
vision_config = hf_config.vision_config

# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# Initialize the vision tower only up to the deepest required feature layer
num_hidden_layers = _get_num_hidden_layers(hf_config)

if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
Expand Down
20 changes: 18 additions & 2 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config

vision_feature_layer = config.vision_feature_layer
# Determine the layer up to which we will initialize the vision tower
if isinstance(vision_feature_layer, int):
vision_hidden_size = config.vision_config.hidden_size
self.feature_sample_layers = None
# Used for multimodal granite models to control encoder outputs
elif isinstance(vision_feature_layer, (list, tuple)):
vision_hidden_size = config.vision_config.hidden_size * len(
vision_feature_layer)
self.feature_sample_layers = vision_feature_layer
else:
raise TypeError(
f"vision_layer_feature type: {type(vision_feature_layer)}"
" is not supported")

self.config = config
self.multimodal_config = multimodal_config

Expand All @@ -300,7 +315,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
vision_hidden_size=vision_hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)

Expand Down Expand Up @@ -419,7 +434,8 @@ def _image_pixels_to_features(

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
image_features = vision_tower(
pixel_values, feature_sample_layers=self.feature_sample_layers)

return self._select_image_features(
image_features,
Expand Down
28 changes: 25 additions & 3 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
consecutive_placeholder_ranges,
resolve_visual_encoder_outputs)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
Expand Down Expand Up @@ -970,9 +971,18 @@ def forward(
x: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
return_all_hidden_states: bool,
) -> torch.Tensor:
hidden_states_pool = []

for layer in self.layers:
x = layer(x, attention_mask, position_embeddings)
if return_all_hidden_states:
hidden_states_pool.append(x)
# If we have multiple feature sample layers, we return all hidden
# states in order and grab the ones we need by index.
if return_all_hidden_states:
return hidden_states_pool
return x


Expand All @@ -990,6 +1000,7 @@ def __init__(
super().__init__()

self.config = config

self.patch_conv = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
Expand Down Expand Up @@ -1024,13 +1035,17 @@ def __init__(
def forward(
self,
pixel_values: List[torch.Tensor],
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
"""
Args:
pixel_values: Each image to be processed will be a separate tensor
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially
feature_sample_layers: Layer indices whose features should be
concatenated and used as the visual encoder output. If none
are provided, the last layer is used.
Returns:
image_features: tensor of token features for
Expand Down Expand Up @@ -1065,8 +1080,15 @@ def forward(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)

out = self.transformer(patch_embeds, attention_mask,
position_embedding)
return_all_hidden_states = feature_sample_layers is not None
out = self.transformer(
patch_embeds,
attention_mask,
position_embedding,
return_all_hidden_states=return_all_hidden_states)

out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)

return out

Expand Down
42 changes: 32 additions & 10 deletions vllm/model_executor/models/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
repeat_and_pad_placeholder_tokens,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData

from .utils import get_vit_attn_backend
Expand Down Expand Up @@ -450,11 +451,19 @@ def __init__(
def forward(
self,
inputs_embeds: torch.Tensor,
) -> torch.Tensor:
return_all_hidden_states: bool,
) -> Union[torch.Tensor, list[torch.Tensor]]:
hidden_states_pool = []
hidden_states = inputs_embeds

for encoder_layer in self.layers:
hidden_states, _ = encoder_layer(hidden_states)

if return_all_hidden_states:
hidden_states_pool.append(hidden_states)
# If we have multiple feature sample layers, we return all hidden
# states in order and grab the ones we need by index.
if return_all_hidden_states:
return hidden_states_pool
return hidden_states


Expand Down Expand Up @@ -509,6 +518,7 @@ def __init__(
embed_dim = config.hidden_size

self.embeddings = SiglipVisionEmbeddings(config)

self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
Expand Down Expand Up @@ -546,23 +556,33 @@ def forward(
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = True,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:

hidden_states = self.embeddings(
pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)

encoder_outputs = self.encoder(inputs_embeds=hidden_states)
return_all_hidden_states = feature_sample_layers is not None

# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states,
)

if self.post_layernorm is None:
return encoder_outputs
# Handle post-norm (if applicable) and stacks feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)

last_hidden_state = self.post_layernorm(encoder_outputs)
# TODO: add this back when pooled_output is used in inference
# TODO: add this back when pooled_output is used in inference.
# if self.use_head:
# pooled_output = self.head(last_hidden_state)
# pooled_output = self.head(encoder_outputs)

return last_hidden_state
return encoder_outputs


class SiglipVisionModel(nn.Module):
Expand Down Expand Up @@ -595,10 +615,12 @@ def forward(
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = False,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
return self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
feature_sample_layers=feature_sample_layers,
)

def load_weights(self, weights: Iterable[Tuple[str,
Expand Down
Loading

0 comments on commit 1cfde82

Please sign in to comment.