Skip to content

Commit

Permalink
[Bugfix] Fix various bugs in multi-modal processor (#12031)
Browse files: browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored Jan 14, 2025
1 parent ff39141 commit bb354e6
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 51 deletions.
19 changes: 19 additions & 0 deletions tests/multimodal/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ def test_find_replace_tokens(
"pattern_1": [32000, 32000],
"pattern_2": [],
"pattern_3": [1550, 918, 1550],
# Test different modalities having the same tokens (32000)
"pattern_4": [32000],
},
],
)
Expand All @@ -438,6 +440,14 @@ def test_find_replace_tokens(
replacement=[32000, 32000],
),
],
"pattern_4": [
PlaceholderInfo(
modality="pattern_4",
item_idx=0,
start_idx=3,
replacement=[32000],
),
],
}
),
Expand Down Expand Up @@ -466,6 +476,7 @@ def test_find_replace_tokens(
replacement=[1550, 918, 1550],
),
],
# No match for pattern_4 as it has lower priority than pattern_1
}
),
(
Expand All @@ -485,6 +496,14 @@ def test_find_replace_tokens(
replacement=[32000, 32000],
),
],
"pattern_4": [
PlaceholderInfo(
modality="pattern_4",
item_idx=0,
start_idx=5,
replacement=[32000],
),
],
"pattern_3": [
PlaceholderInfo(
modality="pattern_3",
Expand Down
89 changes: 39 additions & 50 deletions vllm/multimodal/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,73 +404,62 @@ def replace_text_matches(
return "".join(texts)


def _iter_modality_placeholders(
def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
modality: str,
modality_repls: Sequence[BoundPromptReplacement],
modal_item_count: int,
mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
if modal_item_count == 0:
return
"""
Yield each set of placeholder tokens found in :code:`prompt`.
Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_repls` takes priority.
Note that empty matches are ignored.
"""
prompt_len = len(prompt)
item_idx = 0
item_idx_by_modality = defaultdict[str, int](lambda: 0)

start_idx = 0
while start_idx < prompt_len:
found = False

for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len

if repl_len == 0 or end_idx > prompt_len:
for modality, modality_repls in mm_prompt_repls.items():
item_idx = item_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue

if prompt[start_idx:end_idx] == repl_tokens:
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
replacement=repl_tokens,
)
for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len

if repl_len == 0 or end_idx > prompt_len:
continue

if prompt[start_idx:end_idx] == repl_tokens:
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
replacement=repl_tokens,
)

item_idx += 1
if item_idx >= modal_item_count:
return
# Exclude overlapping matches
start_idx = end_idx
item_idx_by_modality[modality] += 1
found = True
break

# Exclude overlapping matches
start_idx = end_idx
found = True
break
if found:
break # Go back to the outer while loop

if not found:
start_idx += 1


def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
"""
For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
Note that empty matches are ignored.
"""
for modality, modal_item_count in mm_item_counts.items():
if modality in mm_prompt_repls:
yield from _iter_modality_placeholders(
prompt,
modality,
mm_prompt_repls[modality],
modal_item_count,
)


def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
Expand Down Expand Up @@ -1156,7 +1145,7 @@ def apply(

# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
if all(len(repls) == 0 for repls in mm_missing_repls.values()):
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
Expand Down
5 changes: 4 additions & 1 deletion vllm/multimodal/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ def get_max_tokens_per_item_by_modality(
This is currently directly used only in V1.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len
return processor.info.get_mm_max_tokens_per_item(seq_len)
Expand Down

0 comments on commit bb354e6

Please sign in to comment.