[Misc] Clean up multi-modal processor #11207

Merged · 2 commits · Dec 15, 2024
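This PR narrows the interface of the prompt-replacement helpers in vllm/multimodal/processing.py: instead of threading the full MultiModalDataItems object through them, callers now pass a plain Mapping[str, int] of item counts per modality, which is all the helpers ever consumed. A minimal sketch of the conversion, mirroring the line added in apply() below (names taken from the diff):

```python
from typing import Any, Mapping


def to_item_counts(mm_items: Mapping[str, list[Any]]) -> Mapping[str, int]:
    """Collapse per-modality item lists into the counts the helpers need."""
    return {modality: len(items) for modality, items in mm_items.items()}


assert to_item_counts({"image": ["img0", "img1"], "audio": []}) == {
    "image": 2,
    "audio": 0,
}
```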
5 changes: 1 addition & 4 deletions examples/offline_inference_vision_language.py
@@ -92,10 +92,7 @@ def run_fuyu(question: str, modality: str):
def run_phi3v(question: str, modality: str):
    assert modality == "image"

-    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
-    # Note: The default setting of max_num_seqs (256) and
-    # max_model_len (128k) for this model may cause OOM.
-    # You may lower either to run this example on lower-end GPUs.
+    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"

    # num_crops is an override kwarg to the multimodal image processor;
    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
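The comment kept here refers to num_crops, which vLLM forwards to the HF image processor through mm_processor_kwargs. A hedged sketch of how the example constructs the engine; the model name and limits are illustrative, not taken from this diff:

```python
from vllm import LLM

# Illustrative values: lowering max_model_len and max_num_seqs helps avoid
# OOM on lower-end GPUs, and num_crops is forwarded to the HF image
# processor via mm_processor_kwargs.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    mm_processor_kwargs={"num_crops": 16},
)
```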
17 changes: 8 additions & 9 deletions tests/multimodal/test_processing.py
@@ -2,10 +2,9 @@

import pytest

-from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
-                                        _PlaceholderInfo, find_text_matches,
-                                        find_token_matches, iter_placeholders,
-                                        iter_token_matches,
+from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
+                                        find_text_matches, find_token_matches,
+                                        iter_placeholders, iter_token_matches,
                                        replace_text_matches,
                                        replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -314,8 +313,8 @@ def test_find_replace_text(
    result = replace_text_matches(
        prompt,
        matches,
-        MultiModalDataItems({key: [None] * mm_count
-                             for key in repl_by_key}),
+        {key: mm_count
+         for key in repl_by_key},
    )

    # Only displayed on error
@@ -380,8 +379,8 @@ def test_find_replace_tokens(
    result = replace_token_matches(
        prompt,
        matches,
-        MultiModalDataItems({key: [None] * mm_count
-                             for key in repl_by_key}),
+        {key: mm_count
+         for key in repl_by_key},
    )

    # Only displayed on error
@@ -476,7 +475,7 @@ def test_iter_placeholders(
            prompt_repls,
            prompt,
            # Effectively match all occurrences in the prompt
-            MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
+            {key: 3 for key in repl_by_key},
        ))

    # Only displayed on error
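With this change, the tests build the counts mapping inline instead of wrapping dummy item lists in MultiModalDataItems. A small sketch of the equivalence (fixture values are hypothetical stand-ins for repl_by_key and mm_count):

```python
# Hypothetical stand-ins for the repl_by_key and mm_count fixtures.
repl_by_key = {"image": "<image>", "video": "<video>"}
mm_count = 2

# Old shape: MultiModalDataItems({key: [None] * mm_count for key in repl_by_key})
# New shape: just the per-modality counts.
mm_item_counts = {key: mm_count for key in repl_by_key}
assert mm_item_counts == {"image": 2, "video": 2}
```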
48 changes: 23 additions & 25 deletions vllm/multimodal/processing.py
@@ -403,18 +403,17 @@ def _resolve_matches(
def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    out_seqs = list[_S]()
    prev_end_idx = 0
-    next_idx_by_modality = {modality: 0 for modality in mm_items}
+    next_idx_by_modality = {modality: 0 for modality in mm_item_counts}

    for match in _resolve_matches(prompt, matches):
        modality = match.modality
-        modal_items = mm_items[modality]

        item_idx = next_idx_by_modality[modality]
-        if item_idx >= len(modal_items):
+        if item_idx >= mm_item_counts[modality]:
            continue

        start_idx = match.start_idx
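The loop above now bounds replacements by mm_item_counts[modality] instead of len(modal_items); matches beyond the available item count are skipped via continue. A self-contained sketch of that behavior, using strings as a stand-in for the vLLM internals:

```python
from typing import Mapping


def apply_counts(prompt: str,
                 repls: Mapping[str, tuple[str, str]],
                 mm_item_counts: Mapping[str, int]) -> str:
    """Replace at most mm_item_counts[modality] matches per modality."""
    for modality, (placeholder, repl) in repls.items():
        # str.replace's count argument caps the number of replacements,
        # mirroring the `item_idx >= mm_item_counts[modality]` guard above.
        prompt = prompt.replace(placeholder, repl,
                                mm_item_counts.get(modality, 0))
    return prompt


out = apply_counts(
    "<image> <image> <image>",
    {"image": ("<image>", "[IMG]")},
    {"image": 2},  # only two image items were supplied
)
assert out == "[IMG] [IMG] <image>"  # the third match is left as-is
```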
@@ -441,27 +440,27 @@ def _replace_matches(
def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementTokenMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """Apply :code:`prompt_repls` to :code:`prompt`."""
    if not matches:
        return prompt

-    token_id_seqs = _replace_matches(prompt, matches, mm_items)
+    token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)

    return flatten_2d_lists(token_id_seqs)


def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementTextMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
) -> str:
    """Apply :code:`prompt_repls` to :code:`prompt`."""
    if not matches:
        return prompt

-    texts = _replace_matches(prompt, matches, mm_items)
+    texts = _replace_matches(prompt, matches, mm_item_counts)

    return "".join(texts)

@@ -470,9 +469,9 @@ def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[_BoundPromptReplacement],
-    modal_items: list[Any],
+    modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
-    if len(modal_items) == 0:
+    if modal_item_count == 0:
        return

    prompt_len = len(prompt)
@@ -499,7 +498,7 @@
                )

                item_index += 1
-                if item_index >= len(modal_items):
+                if item_index >= modal_item_count:
                    return

                # Exclude overlapping matches
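The placeholder scan applies the same bound: it stops once modal_item_count placeholders have been yielded for a modality. A runnable sketch of the bounded scan (token values are made up):

```python
from typing import Iterable


def iter_first_n_positions(tokens: list[int], placeholder: int,
                           modal_item_count: int) -> Iterable[int]:
    """Yield positions of placeholder, stopping after the item count."""
    found = 0
    for idx, tok in enumerate(tokens):
        if tok == placeholder:
            yield idx
            found += 1
            if found >= modal_item_count:
                return


positions = list(
    iter_first_n_positions([7, 9, 7, 7], placeholder=7, modal_item_count=2))
assert positions == [0, 2]  # the third occurrence is never reached
```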
@@ -514,7 +513,7 @@
def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement],
    prompt: list[int],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
    """
    Yield each set of placeholder tokens found in :code:`prompt`.
@@ -523,13 +522,13 @@
    """
    repls_by_modality = dict(full_groupby_modality(prompt_repls))

-    for modality, modal_items in mm_items.items():
+    for modality, modal_item_count in mm_item_counts.items():
        if modality in repls_by_modality:
            yield from _iter_modality_placeholders(
                prompt,
                modality,
                repls_by_modality[modality],
-                modal_items,
+                modal_item_count,
            )


@@ -590,10 +589,10 @@ def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement],
        new_token_ids: list[int],
-        mm_items: MultiModalDataItems,
+        mm_item_counts: Mapping[str, int],
    ) -> list[_PlaceholderInfo]:
        return list(
-            iter_placeholders(all_prompt_repls, new_token_ids, mm_items))
+            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))

    def _apply_hf_processor(
        self,
@@ -655,10 +654,9 @@ def _bind_prompt_replacements(

    def _apply_prompt_replacements(
        self,
-        mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement],
+        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        tokenizer = self._get_tokenizer()

@@ -675,13 +673,13 @@
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        if all(
-            len(matches) >= len(mm_items[modality])
+            len(matches) >= mm_item_counts[modality]
            for modality, matches in full_groupby_modality(token_matches)
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
-                mm_items,
+                mm_item_counts,
            )

            text = _decode(tokenizer, token_ids)
@@ -693,14 +691,14 @@
            text = replace_text_matches(
                text,
                text_matches,
-                mm_items,
+                mm_item_counts,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids,
-                                               mm_items)
+                                               mm_item_counts)

        return token_ids, text, placeholders
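_apply_prompt_replacements keeps its two-phase strategy: replace at the token level when every modality matched at least as many times as it has items, otherwise decode, replace on the string, and re-encode. A compact sketch of that control flow; the callables are stand-ins for the tokenizer and match helpers, not vLLM's API:

```python
from typing import Callable, Mapping, Sequence


def replace_with_fallback(
    token_ids: list[int],
    token_matches: Mapping[str, Sequence[int]],  # match positions per modality
    mm_item_counts: Mapping[str, int],
    replace_on_tokens: Callable[[list[int]], list[int]],
    replace_on_text: Callable[[str], str],
    decode: Callable[[list[int]], str],
    encode: Callable[[str], list[int]],
) -> list[int]:
    # Token-level replacement is only safe when each modality found at
    # least as many matches as it has items; otherwise the search text
    # may straddle token boundaries, so redo the search on the string.
    if all(
            len(matches) >= mm_item_counts[modality]
            for modality, matches in token_matches.items()):
        return replace_on_tokens(token_ids)
    return encode(replace_on_text(decode(token_ids)))
```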

@@ -737,8 +735,9 @@ def apply(

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
+        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
        all_placeholders = self._find_placeholders(all_prompt_repls,
-                                                   prompt_ids, mm_items)
+                                                   prompt_ids, mm_item_counts)

        if all_placeholders:
            prompt_text = _decode(tokenizer, prompt_ids)
@@ -748,10 +747,9 @@
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
-                mm_items,
-                hf_inputs,
                prompt_ids,
                all_prompt_repls,
+                mm_item_counts,
            )

        mm_placeholders = {