diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 650293d864011..545a2ccaa5634 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -28,7 +28,7 @@ llm = LLM(model=..., task="generate")  # Name or path of your model
 output = llm.generate("Hello, my name is")
 print(output)
 
-# For pooling models (task={embed,classify,reward}) only
+# For pooling models (task={embed,classify,reward,score}) only
 llm = LLM(model=..., task="embed")  # Name or path of your model
 output = llm.encode("Hello, my name is")
 print(output)
@@ -59,7 +59,7 @@ llm = LLM(model=..., revision=..., task=..., trust_remote_code=True)
 output = llm.generate("Hello, my name is")
 print(output)
 
-# For pooling models (task={embed,classify,reward}) only
+# For pooling models (task={embed,classify,reward,score}) only
 output = llm.encode("Hello, my name is")
 print(output)
 ```
@@ -369,14 +369,6 @@ you should explicitly specify the task type to ensure that the model is used in
 
 #### Text Embedding (`--task embed`)
 
-Any text generation model can be converted into an embedding model by passing {code}`--task embed`.
-
-```{note}
-To get the best results, you should use pooling models that are specifically trained as such.
-```
-
-The following table lists those that are tested in vLLM.
-
 ```{eval-rst}
 .. list-table::
   :widths: 25 25 50 5 5
@@ -437,6 +429,10 @@ On the other hand, its 1.5B variant ({code}`Alibaba-NLP/gte-Qwen2-1.5B-instruct`
 despite being described otherwise on its model card.
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
+of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
+
 #### Reward Modeling (`--task reward`)
 
 ```{eval-rst}
@@ -461,6 +457,9 @@ despite being described otherwise on its model card.
     - ✅︎
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_reward_model`. By default, we return the hidden states of each token directly.
+
 ```{important}
 For process-supervised reward models such as {code}`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
 e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
@@ -490,6 +489,9 @@ e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 1
     - ✅︎
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
+
 #### Sentence Pair Scoring (`--task score`)
 
 ```{eval-rst}
diff --git a/tests/models/embedding/language/test_cls_models.py b/tests/models/embedding/language/test_cls_models.py
index 6321503e7b248..6673a9fc22f69 100644
--- a/tests/models/embedding/language/test_cls_models.py
+++ b/tests/models/embedding/language/test_cls_models.py
@@ -1,7 +1,4 @@
-"""Compare the outputs of HF and vLLM when using greedy sampling.
-
-This test only tests small models. Big models such as 7B should be tested from
-test_big_models.py because it could use a larger instance to run tests.
+"""Compare the classification outputs of HF and vLLM models.
 
 Run `pytest tests/models/test_cls_models.py`.
 """
diff --git a/tests/models/embedding/language/test_scoring.py b/tests/models/embedding/language/test_scoring.py
index af31e1a635f65..be6e3842821e2 100644
--- a/tests/models/embedding/language/test_scoring.py
+++ b/tests/models/embedding/language/test_scoring.py
@@ -1,6 +1,6 @@
-"""Compare the embedding outputs of HF and vLLM models.
+"""Compare the scoring outputs of HF and vLLM models.
 
-Run `pytest tests/models/embedding/language/test_embedding.py`.
+Run `pytest tests/models/embedding/language/test_scoring.py`.
 """
 
 import math
diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py
index b5368aab3ecf1..73b70d65e8e0b 100644
--- a/tests/models/test_registry.py
+++ b/tests/models/test_registry.py
@@ -6,7 +6,9 @@
 from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
-from vllm.model_executor.models.adapters import as_embedding_model
+from vllm.model_executor.models.adapters import (as_classification_model,
+                                                 as_embedding_model,
+                                                 as_reward_model)
 from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  _SPECULATIVE_DECODING_MODELS,
                                                  _TEXT_GENERATION_MODELS,
@@ -29,9 +31,10 @@ def test_registry_imports(model_arch):
             or model_arch in _MULTIMODAL_MODELS):
         assert is_text_generation_model(model_cls)
 
-    # All vLLM models should be convertible to an embedding model
-    embed_model = as_embedding_model(model_cls)
-    assert is_pooling_model(embed_model)
+    # All vLLM models should be convertible to a pooling model
+    assert is_pooling_model(as_classification_model(model_cls))
+    assert is_pooling_model(as_embedding_model(model_cls))
+    assert is_pooling_model(as_reward_model(model_cls))
 
     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index f15e7176b3d50..44978a55e072d 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -7,7 +7,9 @@
 
 from vllm.config import ModelConfig
 from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.models.adapters import as_embedding_model
+from vllm.model_executor.models.adapters import (as_classification_model,
+                                                 as_embedding_model,
+                                                 as_reward_model)
 
 
 @contextlib.contextmanager
@@ -35,8 +37,12 @@ def get_model_architecture(
         architectures = ["QuantMixtralForCausalLM"]
 
     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
-    if model_config.runner_type == "pooling":
+    if model_config.task == "embed":
         model_cls = as_embedding_model(model_cls)
+    elif model_config.task == "classify":
+        model_cls = as_classification_model(model_cls)
+    elif model_config.task == "reward":
+        model_cls = as_reward_model(model_cls)
 
     return model_cls, arch
 
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 9cc43ae9181b9..55e90b9d41950 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -1,29 +1,48 @@
 from collections.abc import Iterable
-from typing import Any, TypeVar
+from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
 import torch
 import torch.nn as nn
 
 from .interfaces_base import VllmModelForPooling, is_pooling_model
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.pooler import PoolingType
+
 _T = TypeVar("_T", bound=type[nn.Module])
 
+_GENERATE_SUFFIXES = [
+    "ForCausalLM",
+    "ForConditionalGeneration",
+    "ChatModel",
+    "LMHeadModel",
+]
 
-def as_embedding_model(cls: _T) -> _T:
-    """Subclass an existing vLLM model to support embeddings."""
-    # Avoid modifying existing embedding models
-    if is_pooling_model(cls):
-        return cls
 
+def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
+    model_name = orig_model_name
+
+    for generate_suffix in _GENERATE_SUFFIXES:
+        model_name = model_name.removesuffix(generate_suffix)
+
+    return model_name + pooling_suffix
+
+
+def _create_pooling_model_cls(
+    orig_cls: _T,
+    *,
+    default_pooling_type: "PoolingType",
+    default_normalize: bool,
+    default_softmax: bool,
+) -> _T:
     # Lazy import
     from vllm.config import VllmConfig
-    from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
-                                                   PoolingType)
+    from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
     from vllm.model_executor.pooling_metadata import PoolingMetadata
 
     from .utils import AutoWeightsLoader, WeightsMapper
 
-    class ModelForEmbedding(cls, VllmModelForPooling):
+    class ModelForPooling(orig_cls, VllmModelForPooling):
 
         def __init__(
             self,
@@ -34,7 +53,7 @@ def __init__(
         ) -> None:
             super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
-            # These are not used in embedding models
+            # These are not used in pooling models
             for attr in ("lm_head", "logits_processor"):
                 if hasattr(self, attr):
                     delattr(self, attr)
@@ -46,9 +65,9 @@ def __init__(
             if not getattr(self, "_pooler", None):
                 self._pooler = Pooler.from_config_with_defaults(
                     pooler_config,
-                    pooling_type=PoolingType.LAST,
-                    normalize=True,
-                    softmax=False,
+                    pooling_type=default_pooling_type,
+                    normalize=default_normalize,
+                    softmax=default_softmax,
                 )
 
         def pooler(
@@ -82,17 +101,148 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
                     return
 
             # For most other models
-            if hasattr(cls, "load_weights"):
-                cls.load_weights(self, weights)  # type: ignore
+            if hasattr(orig_cls, "load_weights"):
+                orig_cls.load_weights(self, weights)  # type: ignore
             # Fallback
             else:
                 loader = AutoWeightsLoader(self)
                 loader.load_weights(weights)
 
-    ModelForEmbedding.__name__ = cls.__name__ \
-        .removesuffix("ForCausalLM") \
-        .removesuffix("ForConditionalGeneration") \
-        .removesuffix("ChatModel") \
-        .removesuffix("LMHeadModel") + "ForEmbedding"
+    return ModelForPooling  # type: ignore
+
+
+def as_embedding_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support embeddings.
+
+    By default, the embeddings of the whole prompt are extracted from the
+    normalized hidden state corresponding to the last token.
+
+    Note:
+        We assume that no extra layers are added to the original model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing embedding models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.model_executor.layers.pooler import PoolingType
+
+    ModelForEmbedding = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.LAST,
+        default_normalize=True,
+        default_softmax=False,
+    )
+    ModelForEmbedding.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForEmbedding")
 
     return ModelForEmbedding  # type: ignore
+
+
+def as_classification_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support classification.
+
+    By default, the class probabilities are extracted from the softmaxed
+    hidden state corresponding to the last token.
+
+    Note:
+        We assume that the classification head is a single linear layer
+        stored as the attribute `score` of the top-level model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing classification models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.attention import AttentionMetadata
+    from vllm.config import VllmConfig
+    from vllm.model_executor.layers.linear import RowParallelLinear
+    from vllm.model_executor.layers.pooler import PoolingType
+    from vllm.sequence import IntermediateTensors
+
+    from .utils import maybe_prefix
+
+    ModelForPooling = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.LAST,
+        default_normalize=False,
+        default_softmax=True,
+    )
+
+    class ModelForClassification(ModelForPooling):
+
+        def __init__(
+            self,
+            *,
+            vllm_config: "VllmConfig",
+            prefix: str = "",
+            **kwargs: Any,
+        ) -> None:
+            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
+
+            config = vllm_config.model_config.hf_config
+            quant_config = vllm_config.quant_config
+
+            self.score = RowParallelLinear(config.hidden_size,
+                                           config.num_labels,
+                                           quant_config=quant_config,
+                                           input_is_parallel=False,
+                                           bias=False,
+                                           prefix=maybe_prefix(
+                                               prefix, "score"))
+
+        def forward(
+            self,
+            input_ids: torch.Tensor,
+            positions: torch.Tensor,
+            kv_caches: list[torch.Tensor],
+            attn_metadata: AttentionMetadata,
+            intermediate_tensors: Optional[IntermediateTensors] = None,
+            inputs_embeds: Optional[torch.Tensor] = None,
+        ) -> torch.Tensor:
+            hidden_states = super().forward(input_ids, positions, kv_caches,
+                                            attn_metadata,
+                                            intermediate_tensors,
+                                            inputs_embeds)
+            logits, _ = self.score(hidden_states)
+            return logits
+
+
+    ModelForClassification.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForClassification")
+
+    return ModelForClassification  # type: ignore
+
+
+def as_reward_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support reward modeling.
+
+    By default, we return the hidden states of each token directly.
+
+    Note:
+        We assume that no extra layers are added to the original model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing reward models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.model_executor.layers.pooler import PoolingType
+
+    ModelForReward = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.ALL,
+        default_normalize=False,
+        default_softmax=False,
+    )
+
+    ModelForReward.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForReward")
+
+    return ModelForReward  # type: ignore
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 7661bb285df95..88f4ea4352726 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -545,8 +545,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
-        # TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
-        # after changing the default pooling method
+        # TODO: Replace this model class with as_embedding_model(
+        # Qwen2ForCausalLM) after changing the default pooling method
         if pooler_config.pooling_type is None:
             logger.warning(
                 "This embedding model will default to last-token pooling in "
diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py
deleted file mode 100644
index dc5dabf6fc38b..0000000000000
--- a/vllm/model_executor/models/qwen2_cls.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Adapted from
-# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
-# Copyright 2024 Kakao Corp. (Kanana-X Team)
-# Copyright 2024 The Qwen team.
-# Copyright 2023 The vLLM team.
-"""Inference-only Qwen2-Classification model compatible with HF weights."""
-from typing import Iterable, List, Optional, Set, Tuple
-
-import torch
-from torch import nn
-
-from vllm.attention import AttentionMetadata
-from vllm.config import VllmConfig
-from vllm.model_executor.layers.linear import RowParallelLinear
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.models.qwen2 import Qwen2Model
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
-
-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import AutoWeightsLoader, maybe_prefix
-
-
-class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "gate_up_proj",
-        "down_proj",
-    ]
-    embedding_modules = {}
-    embedding_padding_modules = []
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        lora_config = vllm_config.lora_config
-        pooler_config = vllm_config.model_config.pooler_config
-
-        self.config = config
-        self.lora_config = lora_config
-
-        self.quant_config = quant_config
-        self.model = Qwen2Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-
-        # hidden_states from Qwen2Model has been reduced,
-        # the input of score layer is not parallelized.
-        self.score = RowParallelLinear(config.hidden_size,
-                                       config.num_labels,
-                                       quant_config=quant_config,
-                                       input_is_parallel=False,
-                                       bias=False,
-                                       prefix=maybe_prefix(prefix, "score"))
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=False,
-            softmax=True)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
-                                   inputs_embeds)
-        logits, _ = self.score(hidden_states)
-        return logits
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(self,
-                                   ignore_unexpected_prefixes=["lm_head."])
-        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 04d806c3c7eae..b32a3421d5841 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -20,11 +20,10 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
-from .adapters import as_embedding_model
 from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,
                          supports_pp)
-from .interfaces_base import is_pooling_model, is_text_generation_model
+from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)
 
@@ -125,12 +124,13 @@
     "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
-    "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
+    # [Auto-converted (see adapters.py)]
+    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
 }
 
 _CROSS_ENCODER_MODELS = {
@@ -226,19 +226,10 @@ class _ModelInfo:
 
     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
-        is_pooling_model_ = is_pooling_model(model)
-        if not is_pooling_model_:
-            try:
-                as_embedding_model(model)
-            except Exception:
-                pass
-            else:
-                is_pooling_model_ = True
-
         return _ModelInfo(
             architecture=model.__name__,
             is_text_generation_model=is_text_generation_model(model),
-            is_pooling_model=is_pooling_model_,
+            is_pooling_model=True,  # Can convert any model into a pooling model
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
             supports_pp=supports_pp(model),
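
For reference, here is a minimal sketch (not part of the patch itself) of how the adapters introduced in `vllm/model_executor/models/adapters.py` compose, mirroring the assertions in `tests/models/test_registry.py`. The adapter functions, `is_pooling_model`, and `Qwen2ForCausalLM` are taken from the diff above; everything else is illustrative:

```python
# Illustrative only: exercises the adapter API added by this patch.
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.models.adapters import (as_classification_model,
                                                 as_embedding_model,
                                                 as_reward_model)
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

# Each adapter subclasses the generative model and attaches a default Pooler:
# LAST + normalize for embeddings, LAST + softmax for classification,
# ALL (per-token hidden states) for reward modeling.
Qwen2ForEmbedding = as_embedding_model(Qwen2ForCausalLM)
Qwen2ForClassification = as_classification_model(Qwen2ForCausalLM)
Qwen2ForReward = as_reward_model(Qwen2ForCausalLM)

# _get_pooling_model_name() strips the "ForCausalLM" suffix before appending
# the pooling suffix.
assert Qwen2ForEmbedding.__name__ == "Qwen2ForEmbedding"
assert all(
    is_pooling_model(model_cls)
    for model_cls in (Qwen2ForEmbedding, Qwen2ForClassification,
                      Qwen2ForReward))
```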
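At serving time the same conversion happens automatically: `get_model_architecture()` in `vllm/model_executor/model_loader/utils.py` now wraps the resolved class with `as_embedding_model`, `as_classification_model`, or `as_reward_model` depending on whether the configured task is `embed`, `classify`, or `reward`, which is what lets `Qwen2ForSequenceClassification` be registered as an auto-converted alias of `Qwen2ForCausalLM` instead of a hand-written model file.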