
Commit

add test
sfc-gh-aqiao committed Dec 18, 2024
1 parent 94a867b commit 787708a
Showing 4 changed files with 151 additions and 40 deletions.
Empty file.
107 changes: 107 additions & 0 deletions tests/models/encoder_decoder/audio/test_whisper.py
@@ -0,0 +1,107 @@
"""Compare the outputs of HF and vLLM for Whisper models using greedy sampling.
Run `pytest tests/models/encoder_decoder/audio/test_whisper.py`.
"""
from typing import Optional

import pytest

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

from ....utils import fork_new_process_for_each_test, multi_gpu_test


PROMPTS = [
    {
        "prompt":
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    },
    {  # Test explicit encoder/decoder prompt
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt":
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
    }
]

EXPECTED = {
    "openai/whisper-medium": [
        " The first words I spoke in the original phonograph, a little piece"
        " of practical poetry. Mary had a little lamb, its fleece was quite as"
        " slow, and everywhere that Mary went the lamb was sure to go.",
        " And the old one pitch on the way to Edgar Martinez swung on the line"
        " down the left field line for Obeysmith. Here comes Joy. Here is"
        " Jorgen at third base. They're gonna wave him in. The throw to the"
        " plate will be late. The Mariners are going to play for the American"
        " League Championship. I don't believe it. It just continues. My, oh"
        " my."
    ],
    "openai/whisper-large-v3": [
        " The first words I spoke in the original phonograph. A little piece"
        " of practical poetry. Mary had a little lamb, its fleece was white as"
        " snow, and everywhere that Mary went, the lamb was sure to go.",
        " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line,"
        " down the left field line for a base hit. Here comes Joy. Here is"
        " Junior to third base. They're going to wave him in. The throw to the"
        " plate will be late. The Mariners are going to play for the American"
        " League Championship. I don't believe it. It just continues. My, oh,"
        " my."
    ]
}


def run_test(
    model: str,
    *,
    enforce_eager: bool,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    prompts = PROMPTS * 10
    expected = EXPECTED[model] * 10

    llm = LLM(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        enforce_eager=enforce_eager,
    )

    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=200,
    )

    outputs = llm.generate(prompts, sampling_params)

    for output, expected in zip(outputs, expected):
        print(output.outputs[0].text)
        assert output.outputs[0].text == expected

CI note — GitHub Actions / ruff (3.12), check failure on line 86 in tests/models/encoder_decoder/audio/test_whisper.py: Ruff (B020) "Loop control variable `expected` overrides iterable it iterates".


@fork_new_process_for_each_test
@pytest.mark.parametrize(
    "model", ["openai/whisper-medium", "openai/whisper-large-v3"]
)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_models(model, enforce_eager) -> None:
    run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=1)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", ["openai/whisper-large-v3"])
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
def test_models_distributed(model, enforce_eager,
                            distributed_executor_backend) -> None:
    run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=2,
             distributed_executor_backend=distributed_executor_backend)
1 change: 1 addition & 0 deletions vllm/config.py
@@ -1961,6 +1961,7 @@ def _get_and_verify_max_len(
# Command-R
"model_max_length",
# Others
"max_length",
"max_sequence_length",
"max_seq_length",
"seq_len",
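The new key, max_length, joins the list of config attributes that _get_and_verify_max_len scans when deriving a model's maximum sequence length — presumably because Whisper's HF config exposes max_length rather than the other names. A simplified sketch of that scanning pattern follows; it is an illustration under those assumptions, not vLLM's actual implementation:

# Illustrative sketch (not vLLM's code): scan an HF config for any attribute
# that may encode the maximum sequence length and take the smallest one found.
POSSIBLE_KEYS = [
    "model_max_length",    # Command-R
    "max_length",          # added in this commit (used by e.g. Whisper configs)
    "max_sequence_length",
    "max_seq_length",
    "seq_len",
]

def derive_max_len(hf_config, default: int = 2048) -> int:
    found = [getattr(hf_config, key) for key in POSSIBLE_KEYS
             if getattr(hf_config, key, None) is not None]
    return min(found) if found else default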
83 changes: 43 additions & 40 deletions vllm/model_executor/models/whisper.py
@@ -37,14 +37,18 @@
logger = init_logger(__name__)


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
def sinusoids(
length: int, channels: int, max_timescale: float = 10000
) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
f"Number of channels has to be divisible by 2 for sinusoidal "
f"positional embeddings, got {channels} channels."
)
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
inv_timescales = torch.exp(-log_timescale_increment *
torch.arange(channels // 2))
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)

@@ -269,10 +273,12 @@ def forward(
hidden_states = residual + hidden_states

if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
torch.isinf(hidden_states).any() or
torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
hidden_states = torch.clamp(hidden_states, min=-clamp_value,
max=clamp_value)

return hidden_states

@@ -366,11 +372,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.num_mel_bins = config.num_mel_bins
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_scale = (
math.sqrt(embed_dim) if config.scale_embedding else 1.0)

self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3,
padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2,
padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions,
embed_dim)
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
@@ -380,7 +390,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.layer_norm = nn.LayerNorm(config.d_model)

with torch.no_grad():
self.embed_positions.weight.copy_(sinusoids(*self.embed_positions.weight.shape))
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape))

def forward(
self,
@@ -417,10 +428,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_target_positions
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_scale = (
math.sqrt(config.d_model) if config.scale_embedding else 1.0)

self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model,
self.padding_idx)
self.embed_positions = WhisperPositionalEmbedding(
self.max_target_positions, config.d_model)
self.start_layer, self.end_layer, self.layers = make_layers(
config.decoder_layers,
lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config,
@@ -463,7 +477,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

def forward(
self,
input_features: Optional[torch.FloatTensor],
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
@@ -550,32 +564,16 @@ def get_whisper_processor(
**kwargs,
) -> WhisperProcessor:
"""Gets an whisper processor for the given model name via HuggingFace."""
try:
processor: WhisperProcessor = WhisperProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the whisper processor. If the whisper processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e

return processor
return WhisperProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)


def input_processor_for_whisper(ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
def input_processor_for_whisper(ctx: InputContext, inputs):
multi_modal_data = inputs["encoder"]["multi_modal_data"]
if isinstance(multi_modal_data["audio"], list):
assert len(multi_modal_data["audio"]) == 1
@@ -625,7 +623,8 @@ def input_mapper_for_whisper(
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_max_whisper_audio_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -634,6 +633,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config

CI note — GitHub Actions / ruff (3.12), check failure on line 634 in vllm/model_executor/models/whisper.py: Ruff (F841) "Local variable `multimodal_config` is assigned to but never used".
self.config = config
self.dtype = vllm_config.model_config.dtype

self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
self.unpadded_vocab_size = config.vocab_size
@@ -655,7 +655,10 @@ def forward(
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]]
input_features = kwargs.get("input_features")
if input_features is not None:
input_features = [feat.to(self.dtype) for feat in input_features]
decoder_outputs = self.model(
input_features=input_features,
input_ids=input_ids,
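As background for the reformatted sinusoids helper above: it builds the fixed sinusoidal table copied into the encoder's position embeddings. Below is a self-contained sketch of the same computation and the shape it produces; the 1500 x 1280 size is the Whisper-large case and is used here only for illustration.

# Standalone sketch of what sinusoids(length, channels) computes: the first
# channels // 2 columns are sin(t * timescale_i), the rest cos(t * timescale_i),
# with timescales spaced geometrically from 1 down to 1 / max_timescale.
import math
import torch

def sinusoids_ref(length: int, channels: int, max_timescale: float = 10000):
    assert channels % 2 == 0
    log_inc = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_inc * torch.arange(channels // 2))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)

table = sinusoids_ref(1500, 1280)
print(table.shape)  # torch.Size([1500, 1280])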
