From 5f4bbf7ea2e5cd4f41e47cc501a430d8b4791deb Mon Sep 17 00:00:00 2001 From: Matthias Vogler Date: Mon, 23 Dec 2024 14:17:59 +0100 Subject: [PATCH 1/9] added molmo lora Signed-off-by: Matthias Vogler --- vllm/model_executor/models/molmo.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 9f744b6918818..53b0329a9f4f2 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -15,7 +15,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.layer import MultiHeadAttention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -43,7 +43,7 @@ SequenceData) from vllm.transformers_utils.processor import get_processor -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) @@ -1121,9 +1121,19 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) -class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): +class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + supported_lora_modules = [ + "transformer.blocks.22.att_proj", + "transformer.blocks.22.ff_proj", + "transformer.blocks.23.att_proj", + "transformer.blocks.23.ff_proj", + "transformer.blocks.16.att_proj", + "transformer.blocks.16.ff_proj", + "transformer.blocks.8.att_proj", + "transformer.blocks.8.ff_proj", + "transformer.blocks.20.att_proj", + ] + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", lora_config: Optional[LoRAConfig] = None): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -1152,6 +1162,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + + self.lora_config = lora_config def _parse_and_validate_image_input( self, From 1ca7a966681958f2c0c9e80d4c3be067eb74212e Mon Sep 17 00:00:00 2001 From: Matthias Vogler Date: Mon, 23 Dec 2024 17:09:39 +0100 Subject: [PATCH 2/9] added Molmo Lora Support Signed-off-by: Matthias Vogler --- vllm/model_executor/models/molmo.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 53b0329a9f4f2..70876edb9e92a 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1122,17 +1122,18 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "att_proj": ["att_proj"], + "attn_out": ["attn_out"], + "ff_proj": ["ff_proj"], + "ff_out": ["ff_out"], + } supported_lora_modules = [ - "transformer.blocks.22.att_proj", - "transformer.blocks.22.ff_proj", - "transformer.blocks.23.att_proj", - "transformer.blocks.23.ff_proj", - "transformer.blocks.16.att_proj", - "transformer.blocks.16.ff_proj", - "transformer.blocks.8.att_proj", - "transformer.blocks.8.ff_proj", - "transformer.blocks.20.att_proj", + "att_proj", + "ff_proj", ] + embedding_modules = {} + embedding_padding_modules = {} def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", lora_config: Optional[LoRAConfig] = None): super().__init__() config = vllm_config.model_config.hf_config @@ -1164,6 +1165,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", lora_config: Op self.model.make_empty_intermediate_tensors) self.lora_config = lora_config + def _parse_and_validate_image_input( self, From 2bb466c1b05500944c6262cb433d9bef41443f2b Mon Sep 17 00:00:00 2001 From: Matthias Vogler Date: Tue, 24 Dec 2024 09:36:04 +0100 Subject: [PATCH 3/9] format code, edit docs Signed-off-by: Matthias Vogler --- docs/source/models/supported_models.md | 2 +- vllm/model_executor/models/molmo.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 650293d864011..4615960fb9dec 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -660,7 +660,7 @@ See [this page](#generative-models) for more information on how to use generativ - Molmo - T + I - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. - - + - ✅︎ - ✅︎ - ✅︎ * - :code:`NVLM_D_Model` diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 70876edb9e92a..9b9a2310040f2 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1134,7 +1134,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ] embedding_modules = {} embedding_padding_modules = {} - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", lora_config: Optional[LoRAConfig] = None): + def __init__( + self, *, + vllm_config: VllmConfig, + prefix: str = "", + lora_config: Optional[LoRAConfig] = None): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config From 037bec88ac3a751a11c443e99bc1cff7bc995452 Mon Sep 17 00:00:00 2001 From: Matthias Vogler Date: Tue, 24 Dec 2024 10:03:17 +0100 Subject: [PATCH 4/9] format code Signed-off-by: Matthias Vogler --- vllm/model_executor/models/molmo.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 9b9a2310040f2..0eb4a77d31faf 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -15,7 +15,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.layer import MultiHeadAttention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -43,7 +43,7 @@ SequenceData) from vllm.transformers_utils.processor import get_processor -from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) @@ -1121,7 +1121,8 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) -class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): +class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, + SupportsLoRA): packed_modules_mapping = { "att_proj": ["att_proj"], "attn_out": ["attn_out"], @@ -1134,11 +1135,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ] embedding_modules = {} embedding_padding_modules = {} - def __init__( - self, *, - vllm_config: VllmConfig, - prefix: str = "", - lora_config: Optional[LoRAConfig] = None): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + lora_config: Optional[LoRAConfig] = None): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -1167,9 +1169,8 @@ def __init__( self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - + self.lora_config = lora_config - def _parse_and_validate_image_input( self, From 39527376d34764e24f34e32682d02d3b81394820 Mon Sep 17 00:00:00 2001 From: Matthias Vogler Date: Thu, 26 Dec 2024 13:19:22 +0100 Subject: [PATCH 5/9] obtain lora_config directly from vllm_config Signed-off-by: Matthias Vogler --- vllm/model_executor/models/molmo.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 12ed362441dd2..96bab45225731 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1164,17 +1164,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, embedding_modules = {} embedding_padding_modules = {} - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - lora_config: Optional[LoRAConfig] = None): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config + self.lora_config = vllm_config.lora_config vision_config = VisionBackboneConfig() self.vision_backbone = MolmoVisionBackbone(config, vision_config, @@ -1198,8 +1195,6 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self.lora_config = lora_config - def _parse_and_validate_image_input( self, **kwargs: object, From c81e8685a50a9dac751bb9006d4f9ce988aee933 Mon Sep 17 00:00:00 2001 From: Matthias Vogler Date: Thu, 26 Dec 2024 13:37:05 +0100 Subject: [PATCH 6/9] format code --- vllm/model_executor/models/molmo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 96bab45225731..0ac0f1de1f15c 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -15,7 +15,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.layer import MultiHeadAttention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, From fa01a174abfa4a0fab53e7dbab3e12d03c533f86 Mon Sep 17 00:00:00 2001 From: Matthias Vogler Date: Thu, 26 Dec 2024 13:42:44 +0100 Subject: [PATCH 7/9] clean up code --- vllm/model_executor/models/molmo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 0ac0f1de1f15c..d393b60869a79 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1169,9 +1169,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config + lora_config = vllm_config.lora_config self.config = config self.multimodal_config = multimodal_config - self.lora_config = vllm_config.lora_config + self.lora_config = lora_config vision_config = VisionBackboneConfig() self.vision_backbone = MolmoVisionBackbone(config, vision_config, From b2c16c7a799c6a89ffac95a2410b147b5394ac39 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 27 Dec 2024 02:12:16 +0000 Subject: [PATCH 8/9] Add modules Signed-off-by: Jee Jee Li --- vllm/model_executor/models/molmo.py | 38 +++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 019ccc940febe..13775356b3878 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer @@ -1158,17 +1159,30 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ) packed_modules_mapping = { - "att_proj": ["att_proj"], - "attn_out": ["attn_out"], - "ff_proj": ["ff_proj"], - "ff_out": ["ff_out"], + "qkv_proj": ["qkv_proj"], + "gate_up_proj": ["gate_up_proj"], # language model + "merged_linear": ["gate_proj", "up_proj"] # image_projector } + + # LoRA specific attributes supported_lora_modules = [ - "att_proj", - "ff_proj", + # language model + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", # same name with image_projector + # vision tower + "wq", + "wk", + "wv", + "wo", + "w1", + "w2", + # image_projector + "merged_linear", ] embedding_modules = {} - embedding_padding_modules = {} + embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -1352,6 +1366,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weights = _get_weights_with_merged_embedding(weights) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model", + connector="vision_backbone.image_projector", + tower_model="vision_backbone", + ) + def _get_weights_with_merged_embedding( weights: Iterable[Tuple[str, torch.Tensor]] From 789c8883e34cb6441f06833273fb2316aa7c9246 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 27 Dec 2024 02:19:23 +0000 Subject: [PATCH 9/9] format Signed-off-by: Jee Jee Li --- vllm/model_executor/models/molmo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 13775356b3878..684ac806b299f 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1170,7 +1170,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, "qkv_proj", "o_proj", "gate_up_proj", - "down_proj", # same name with image_projector + "down_proj", # same name with image_projector # vision tower "wq", "wk",