Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix various bugs in multi-modal processor #12031

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tests/multimodal/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ def test_find_replace_tokens(
"pattern_1": [32000, 32000],
"pattern_2": [],
"pattern_3": [1550, 918, 1550],
# Test different modalities having the same tokens (32000)
"pattern_4": [32000],
},
],
)
Expand All @@ -438,6 +440,14 @@ def test_find_replace_tokens(
replacement=[32000, 32000],
),
],
"pattern_4": [
PlaceholderInfo(
modality="pattern_4",
item_idx=0,
start_idx=3,
replacement=[32000],
),
],
}

),
Expand Down Expand Up @@ -466,6 +476,7 @@ def test_find_replace_tokens(
replacement=[1550, 918, 1550],
),
],
# No match for pattern_4 as it has lower priority than pattern_1
}
),
(
Expand All @@ -485,6 +496,14 @@ def test_find_replace_tokens(
replacement=[32000, 32000],
),
],
"pattern_4": [
PlaceholderInfo(
modality="pattern_4",
item_idx=0,
start_idx=5,
replacement=[32000],
),
],
"pattern_3": [
PlaceholderInfo(
modality="pattern_3",
Expand Down
89 changes: 39 additions & 50 deletions vllm/multimodal/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,73 +404,62 @@ def replace_text_matches(
return "".join(texts)


def _iter_modality_placeholders(
def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
modality: str,
modality_repls: Sequence[BoundPromptReplacement],
modal_item_count: int,
mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
if modal_item_count == 0:
return
"""
Yield each set of placeholder tokens found in :code:`prompt`.

Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_repls` takes priority.

Note that empty matches are ignored.
"""
prompt_len = len(prompt)
item_idx = 0
item_idx_by_modality = defaultdict[str, int](lambda: 0)

start_idx = 0
while start_idx < prompt_len:
found = False

for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len

if repl_len == 0 or end_idx > prompt_len:
for modality, modality_repls in mm_prompt_repls.items():
item_idx = item_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue

if prompt[start_idx:end_idx] == repl_tokens:
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
replacement=repl_tokens,
)
for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len

if repl_len == 0 or end_idx > prompt_len:
continue

if prompt[start_idx:end_idx] == repl_tokens:
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
replacement=repl_tokens,
)

item_idx += 1
if item_idx >= modal_item_count:
return
# Exclude overlapping matches
start_idx = end_idx
item_idx_by_modality[modality] += 1
found = True
break

# Exclude overlapping matches
start_idx = end_idx
found = True
break
if found:
break # Go back to the outer while loop

if not found:
start_idx += 1


def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
"""
For each modality, yield each set of placeholder tokens found in
:code:`prompt`.

Note that empty matches are ignored.
"""
for modality, modal_item_count in mm_item_counts.items():
if modality in mm_prompt_repls:
yield from _iter_modality_placeholders(
prompt,
modality,
mm_prompt_repls[modality],
modal_item_count,
)


def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
Expand Down Expand Up @@ -1156,7 +1145,7 @@ def apply(

# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
if all(len(repls) == 0 for repls in mm_missing_repls.values()):
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
Expand Down
5 changes: 4 additions & 1 deletion vllm/multimodal/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ def get_max_tokens_per_item_by_modality(
This is currently directly used only in V1.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len
return processor.info.get_mm_max_tokens_per_item(seq_len)
Expand Down
Loading