diff --git a/tests/models/encoder_decoder/audio/__init__.py b/tests/models/encoder_decoder/audio/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio/test_whisper.py
new file mode 100644
index 0000000000000..6ddbf8f579a07
--- /dev/null
+++ b/tests/models/encoder_decoder/audio/test_whisper.py
@@ -0,0 +1,107 @@
+"""Compare the outputs of HF and vLLM for Whisper models using greedy sampling.
+
+Run `pytest tests/models/encoder_decoder/audio/test_whisper.py`.
+"""
+from typing import Optional
+
+import pytest
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+
+from ....utils import fork_new_process_for_each_test, multi_gpu_test
+
+
+PROMPTS = [
+    {
+        "prompt":
+        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
+        "multi_modal_data": {
+            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
+        },
+    },
+    { # Test explicit encoder/decoder prompt
+        "encoder_prompt": {
+            "prompt": "",
+            "multi_modal_data": {
+                "audio": AudioAsset("winning_call").audio_and_sample_rate,
+            },
+        },
+        "decoder_prompt":
+        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
+    }
+]
+
+EXPECTED = {
+    "openai/whisper-medium": [
+        " The first words I spoke in the original phonograph, a little piece"
+        " of practical poetry. Mary had a little lamb, its fleece was quite as"
+        " slow, and everywhere that Mary went the lamb was sure to go.",
+        " And the old one pitch on the way to Edgar Martinez swung on the line"
+        " down the left field line for Obeysmith. Here comes Joy. Here is"
+        " Jorgen at third base. They're gonna wave him in. The throw to the"
+        " plate will be late. The Mariners are going to play for the American"
+        " League Championship. I don't believe it. It just continues. My, oh"
+        " my."
+    ],
+    "openai/whisper-large-v3": [
+        " The first words I spoke in the original phonograph. A little piece"
+        " of practical poetry. Mary had a little lamb, its fleece was white as"
+        " snow, and everywhere that Mary went, the lamb was sure to go.",
+        " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line,"
+        " down the left field line for a base hit. Here comes Joy. Here is"
+        " Junior to third base. They're going to wave him in. The throw to the"
+        " plate will be late. The Mariners are going to play for the American"
+        " League Championship. I don't believe it. It just continues. My, oh,"
+        " my."
+    ]
+}
+
+
+def run_test(
+    model: str,
+    *,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+) -> None:
+    prompts = PROMPTS * 10
+    expected = EXPECTED[model] * 10
+
+    llm = LLM(
+        model=model,
+        tensor_parallel_size=tensor_parallel_size,
+        distributed_executor_backend=distributed_executor_backend,
+        enforce_eager=enforce_eager,
+    )
+
+    sampling_params = SamplingParams(
+        temperature=0,
+        top_p=1.0,
+        max_tokens=200,
+    )
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    for output, expected in zip(outputs, expected):
+        print(output.outputs[0].text)
+        assert output.outputs[0].text == expected
+
+
+@fork_new_process_for_each_test
+@pytest.mark.parametrize(
+    "model", ["openai/whisper-medium", "openai/whisper-large-v3"]
+)
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_models(model, enforce_eager) -> None:
+    run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=1)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("model", ["openai/whisper-large-v3"])
+@pytest.mark.parametrize("enforce_eager", [True, False])
+@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
+def test_models_distributed(model, enforce_eager,
+                            distributed_executor_backend) -> None:
+    run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=2,
+             distributed_executor_backend=distributed_executor_backend)
diff --git a/vllm/config.py b/vllm/config.py
index 08a7b607630af..d481c8c31b5a3
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1961,6 +1961,7 @@ def _get_and_verify_max_len(
         # Command-R
         "model_max_length",
         # Others
+        "max_length",
         "max_sequence_length",
         "max_seq_length",
         "seq_len",
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 0c45af884395d..ad9c1a7d71b5f
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -37,14 +37,18 @@
 logger = init_logger(__name__)
 
 
-def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
+def sinusoids(
+    length: int, channels: int, max_timescale: float = 10000
+) -> torch.Tensor:
     """Returns sinusoids for positional embedding"""
     if channels % 2 != 0:
         raise ValueError(
-            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+            f"Number of channels has to be divisible by 2 for sinusoidal "
+            f"positional embeddings, got {channels} channels."
         )
     log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
-    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    inv_timescales = torch.exp(-log_timescale_increment *
+                               torch.arange(channels // 2))
     scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
     return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
 
@@ -269,10 +273,12 @@ def forward(
         hidden_states = residual + hidden_states
 
         if hidden_states.dtype == torch.float16 and (
-            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+            torch.isinf(hidden_states).any() or
+            torch.isnan(hidden_states).any()
         ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
-            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value,
+                                        max=clamp_value)
 
         return hidden_states
 
@@ -366,11 +372,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_mel_bins = config.num_mel_bins
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_source_positions
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
-
-        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
-        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
-        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
+        self.embed_scale = (
+            math.sqrt(embed_dim) if config.scale_embedding else 1.0)
+
+        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3,
+                               padding=1)
+        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2,
+                               padding=1)
+        self.embed_positions = nn.Embedding(self.max_source_positions,
+                                            embed_dim)
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.encoder_layers,
             lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
@@ -380,7 +390,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.layer_norm = nn.LayerNorm(config.d_model)
 
         with torch.no_grad():
-            self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape))
+            self.embed_positions.weight.copy_(
+                sinusoids(*self.embed_positions.weight.shape))
 
     def forward(
         self,
@@ -417,10 +428,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_target_positions
         self.max_source_positions = config.max_source_positions
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.embed_scale = (
+            math.sqrt(config.d_model) if config.scale_embedding else 1.0)
 
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-        self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model,
+                                         self.padding_idx)
+        self.embed_positions = WhisperPositionalEmbedding(
+            self.max_target_positions, config.d_model)
         self.start_layer, self.end_layer, self.layers = make_layers(
            config.decoder_layers,
            lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config,
@@ -463,7 +477,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
     def forward(
         self,
-        input_features: Optional[torch.FloatTensor],
+        input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
@@ -550,32 +564,16 @@ def get_whisper_processor(
     **kwargs,
 ) -> WhisperProcessor:
     """Gets an whisper processor for the given model name via HuggingFace."""
-    try:
-        processor: WhisperProcessor = WhisperProcessor.from_pretrained(
-            processor_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            **kwargs)
-    except ValueError as e:
-        # If the error pertains to the processor class not existing or not
-        # currently being imported, suggest using the --trust-remote-code flag.
-        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
-        if not trust_remote_code:
-            err_msg = (
-                "Failed to load the whisper processor. If the whisper processor is "
-                "a custom processor not yet available in the HuggingFace "
-                "transformers library, consider setting "
-                "`trust_remote_code=True` in LLM or using the "
-                "`--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
-        else:
-            raise e
-
-    return processor
+    return WhisperProcessor.from_pretrained(
+        processor_name,
+        *args,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+        **kwargs,
+    )
 
 
-def input_processor_for_whisper(ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
+def input_processor_for_whisper(ctx: InputContext, inputs):
     multi_modal_data = inputs["encoder"]["multi_modal_data"]
     if isinstance(multi_modal_data["audio"], list):
         assert len(multi_modal_data["audio"]) == 1
@@ -625,7 +623,8 @@ def input_mapper_for_whisper(
 @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
 @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_max_whisper_audio_tokens)
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+    "audio", get_max_whisper_audio_tokens)
 class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -634,6 +633,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
+        self.dtype = vllm_config.model_config.dtype
         self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
         self.unpadded_vocab_size = config.vocab_size
@@ -655,7 +655,10 @@ def forward(
         attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
+        input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]]
         input_features = kwargs.get("input_features")
+        if input_features is not None:
+            input_features = [feat.to(self.dtype) for feat in input_features]
         decoder_outputs = self.model(
             input_features=input_features,
             input_ids=input_ids,
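
For reference, a minimal offline-transcription sketch distilled from the new test file above. The prompt layout, `multi_modal_data` key, and `AudioAsset` usage come directly from `test_whisper.py`; the specific model name and sampling values below are illustrative assumptions, not something this patch prescribes.

# Usage sketch based on tests/models/encoder_decoder/audio/test_whisper.py.
# Assumes the Whisper support from this diff is available in the installed vLLM.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(model="openai/whisper-large-v3")  # model choice is an example only

prompt = {
    # Decoder prompt uses Whisper's special task tokens; the audio clip is
    # attached as encoder-side multi-modal data, exactly as in the test.
    "prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
    "multi_modal_data": {
        "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
    },
}

outputs = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=200))
print(outputs[0].outputs[0].text)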