From f7296142620dff629cc9f036cda910cfec924996 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Sun, 15 Dec 2024 04:42:27 +0000
Subject: [PATCH 1/2] Clean up multimodal processor

Signed-off-by: DarkLight1337
---
 tests/multimodal/test_processing.py | 17 +++++-----
 vllm/multimodal/processing.py       | 48 ++++++++++++++----------------
 2 files changed, 31 insertions(+), 34 deletions(-)

diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 6aaa80ddc9fa5..d22d778f81fa8 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -2,10 +2,9 @@
 
 import pytest
 
-from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
-                                        _PlaceholderInfo, find_text_matches,
-                                        find_token_matches, iter_placeholders,
-                                        iter_token_matches,
+from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
+                                        find_text_matches, find_token_matches,
+                                        iter_placeholders, iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -314,8 +313,8 @@ def test_find_replace_text(
     result = replace_text_matches(
         prompt,
         matches,
-        MultiModalDataItems({key: [None] * mm_count
-                             for key in repl_by_key}),
+        {key: mm_count
+         for key in repl_by_key},
     )
 
     # Only displayed on error
@@ -380,8 +379,8 @@ def test_find_replace_tokens(
     result = replace_token_matches(
         prompt,
         matches,
-        MultiModalDataItems({key: [None] * mm_count
-                             for key in repl_by_key}),
+        {key: mm_count
+         for key in repl_by_key},
     )
 
     # Only displayed on error
@@ -476,7 +475,7 @@ def test_iter_placeholders(
         prompt_repls,
         prompt,
         # Effectively match all occurrences in the prompt
-        MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
+        {key: 3 for key in repl_by_key},
     ))
 
     # Only displayed on error
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index de5a002d474c2..ce6bec1d49aac 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -403,18 +403,17 @@ def _resolve_matches(
 def _replace_matches(
     prompt: _S,
     matches: Sequence[_PromptReplacementMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> list[_S]:
     out_seqs = list[_S]()
     prev_end_idx = 0
-    next_idx_by_modality = {modality: 0 for modality in mm_items}
+    next_idx_by_modality = {modality: 0 for modality in mm_item_counts}
 
     for match in _resolve_matches(prompt, matches):
         modality = match.modality
-        modal_items = mm_items[modality]
 
         item_idx = next_idx_by_modality[modality]
-        if item_idx >= len(modal_items):
+        if item_idx >= mm_item_counts[modality]:
             continue
 
         start_idx = match.start_idx
@@ -441,13 +440,13 @@
 def replace_token_matches(
     prompt: list[int],
     matches: Sequence[_PromptReplacementTokenMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> list[int]:
     """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt
 
-    token_id_seqs = _replace_matches(prompt, matches, mm_items)
+    token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)
 
     return flatten_2d_lists(token_id_seqs)
 
@@ -455,13 +454,13 @@ def replace_token_matches(
 def replace_text_matches(
     prompt: str,
     matches: Sequence[_PromptReplacementTextMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> str:
     """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt
 
-    texts = _replace_matches(prompt, matches, mm_items)
+    texts = _replace_matches(prompt, matches, mm_item_counts)
 
     return "".join(texts)
 
@@ -470,9 +469,9 @@ def _iter_modality_placeholders(
     prompt: list[int],
     modality: str,
     modality_repls: Sequence[_BoundPromptReplacement],
-    modal_items: list[Any],
+    modal_item_count: int,
 ) -> Iterable[_PlaceholderInfo]:
-    if len(modal_items) == 0:
+    if modal_item_count == 0:
         return
 
     prompt_len = len(prompt)
@@ -499,7 +498,7 @@ def _iter_modality_placeholders(
                 )
 
                 item_index += 1
-                if item_index >= len(modal_items):
+                if item_index >= modal_item_count:
                     return
 
                 # Exclude overlapping matches
@@ -514,7 +513,7 @@
 def iter_placeholders(
     prompt_repls: Sequence[_BoundPromptReplacement],
     prompt: list[int],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> Iterable[_PlaceholderInfo]:
     """
     Yield each set of placeholder tokens found in :code:`prompt`.
@@ -523,13 +522,13 @@
     """
     repls_by_modality = dict(full_groupby_modality(prompt_repls))
 
-    for modality, modal_items in mm_items.items():
+    for modality, modal_item_count in mm_item_counts.items():
         if modality in repls_by_modality:
             yield from _iter_modality_placeholders(
                 prompt,
                 modality,
                 repls_by_modality[modality],
-                modal_items,
+                modal_item_count,
             )
 
 
@@ -590,10 +589,10 @@ def _find_placeholders(
         self,
         all_prompt_repls: Sequence[_BoundPromptReplacement],
         new_token_ids: list[int],
-        mm_items: MultiModalDataItems,
+        mm_item_counts: Mapping[str, int],
     ) -> list[_PlaceholderInfo]:
         return list(
-            iter_placeholders(all_prompt_repls, new_token_ids, mm_items))
+            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
 
     def _apply_hf_processor(
         self,
@@ -655,10 +654,9 @@ def _bind_prompt_replacements(
 
     def _apply_prompt_replacements(
         self,
-        mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
         token_ids: list[int],
         prompt_repls: Sequence[_BoundPromptReplacement],
+        mm_item_counts: Mapping[str, int],
     ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
         tokenizer = self._get_tokenizer()
 
@@ -675,13 +673,13 @@
         # of the search text in the prompt, we instead perform string
         # replacement on the decoded token IDs, then encode them back.
         if all(
-            len(matches) >= len(mm_items[modality])
+            len(matches) >= mm_item_counts[modality]
            for modality, matches in full_groupby_modality(token_matches)
         ):  # yapf: disable
             token_ids = replace_token_matches(
                 token_ids,
                 token_matches,
-                mm_items,
+                mm_item_counts,
             )
 
             text = _decode(tokenizer, token_ids)
@@ -693,14 +691,14 @@
             text = replace_text_matches(
                 text,
                 text_matches,
-                mm_items,
+                mm_item_counts,
             )
 
             token_ids = _encode(tokenizer, text)
             matched_repls = [match.prompt_repl for match in text_matches]
 
         placeholders = self._find_placeholders(matched_repls, token_ids,
-                                               mm_items)
+                                               mm_item_counts)
 
         return token_ids, text, placeholders
 
@@ -737,8 +735,9 @@
 
         # If HF processor already inserts placeholder tokens,
         # there is no need for us to insert them
+        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
         all_placeholders = self._find_placeholders(all_prompt_repls,
-                                                   prompt_ids, mm_items)
+                                                   prompt_ids, mm_item_counts)
 
         if all_placeholders:
             prompt_text = _decode(tokenizer, prompt_ids)
@@ -748,10 +747,9 @@
             prompt_text,
             all_placeholders,
         ) = self._apply_prompt_replacements(
-            mm_items,
-            hf_inputs,
             prompt_ids,
             all_prompt_repls,
+            mm_item_counts,
         )
 
         mm_placeholders = {

From 68e0fcd0fb7e8c2bfde6b87ac3821053417489a6 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Sun, 15 Dec 2024 04:42:34 +0000
Subject: [PATCH 2/2] Remove outdated code

Signed-off-by: DarkLight1337
---
 examples/offline_inference_vision_language.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py
index 45539c665a922..7bc43242b717e 100644
--- a/examples/offline_inference_vision_language.py
+++ b/examples/offline_inference_vision_language.py
@@ -92,10 +92,7 @@ def run_fuyu(question: str, modality: str):
 def run_phi3v(question: str, modality: str):
     assert modality == "image"
 
-    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
-    # Note: The default setting of max_num_seqs (256) and
-    # max_model_len (128k) for this model may cause OOM.
-    # You may lower either to run this example on lower-end GPUs.
+    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
 
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
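The change in PATCH 1/2 rests on one observation: the replacement and
placeholder-iteration loops never inspect the multimodal items themselves,
only how many items exist per modality, so a plain `Mapping[str, int]` of
counts can stand in for the full `MultiModalDataItems`. Below is a minimal,
self-contained sketch of that per-modality capping logic; the helper name
and the data are hypothetical and independent of the vLLM internals above.

    from collections.abc import Mapping, Sequence


    def cap_matches_per_modality(
        matches: Sequence[tuple[str, int]],  # (modality, match position)
        mm_item_counts: Mapping[str, int],
    ) -> list[tuple[str, int]]:
        # Mirrors the `item_idx >= mm_item_counts[modality]` check in
        # `_replace_matches`: keep at most one match per available item.
        next_idx_by_modality = {modality: 0 for modality in mm_item_counts}

        kept = []
        for modality, pos in matches:
            item_idx = next_idx_by_modality[modality]
            if item_idx >= mm_item_counts[modality]:
                continue  # more matches than items: leave extras untouched
            kept.append((modality, pos))
            next_idx_by_modality[modality] = item_idx + 1

        return kept


    # Two image placeholders found in the prompt, but only one image item:
    # only the first match is kept, exactly as in the patched loop.
    assert cap_matches_per_modality(
        [("image", 5), ("image", 12)],
        {"image": 1},
    ) == [("image", 5)]

Seen this way, the test-side change from
`MultiModalDataItems({key: [None] * mm_count ...})` to `{key: mm_count ...}`
is purely mechanical: the `[None] * mm_count` lists existed only to carry a
length.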