fix(test): use tiny model for audio transformers
g-prz committed Dec 4, 2024
1 parent 31cc467 commit 210ac07
Showing 5 changed files with 39 additions and 26 deletions.
19 changes: 7 additions & 12 deletions outlines/generate/api.py
@@ -691,23 +691,18 @@ def _validate_prompt_media_types(
         def valid_types(prompts, media):
             import numpy as np  # type: ignore

-            if isinstance(prompts, list):
-                if not isinstance(media, list) or len(prompts) != len(media):
-                    return False
-                for subprompt, submedia in zip(prompts, media):
-                    if not isinstance(subprompt, str) or not all(
-                        isinstance(m, np.ndarray) for m in submedia
-                    ):
-                        return False
-            elif isinstance(prompts, str):
-                if not all(isinstance(m, np.ndarray) for m in media):
-                    return False
+            if not isinstance(prompts, (str, list)):
+                return False
+            if not isinstance(media, list):
+                return False
+            if not all(isinstance(m, np.ndarray) for m in media):
+                return False
             return True

         if not valid_types(prompts, media):
             raise TypeError(
                 "Expected (prompts, media) to be of type "
-                "(str, List[np.ndarray])), or (List[str], List[List[np.ndarray]]) "
+                "(str, List[np.ndarray])), or (List[str], List[np.ndarray]]) "
                 f"instead got prompts={prompts}, media={media}"
             )

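For context, a minimal sketch (not part of the commit) of what the simplified check accepts and rejects. valid_types is a nested helper inside _validate_prompt_media_types, so it is lifted out here purely for illustration:

    import numpy as np

    def valid_types(prompts, media):
        # Copy of the nested helper introduced by this commit, lifted out for illustration.
        if not isinstance(prompts, (str, list)):
            return False
        if not isinstance(media, list):
            return False
        if not all(isinstance(m, np.ndarray) for m in media):
            return False
        return True

    audio = np.random.random(20000).astype(np.float32)

    assert valid_types("single prompt", [audio])                # str + List[np.ndarray]
    assert valid_types(["prompt0", "prompt1"], [audio, audio])  # List[str] + flat List[np.ndarray]
    assert valid_types(["prompt0", "prompt1"], [audio])         # prompt/media lengths are no longer compared
    assert not valid_types("single prompt", audio)              # media must be a list
    assert not valid_types("single prompt", [[audio]])          # nested media lists are rejected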
6 changes: 5 additions & 1 deletion outlines/models/transformers_audio.py
@@ -44,7 +44,11 @@ def generate(  # type: ignore
         The generated text
         """
         inputs = self.processor(
-            text=prompts, audios=media, padding=True, return_tensors="pt"
+            text=prompts,
+            audios=media,
+            padding=True,
+            return_tensors="pt",
+            sampling_rate=self.processor.feature_extractor.sampling_rate,
         ).to(self.model.device)

         generation_kwargs = self._get_generation_kwargs(
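A rough standalone sketch of the patched processor call, reusing the tiny checkpoint adopted by this commit; the keyword names mirror the diff above, though the exact processor signature may vary across transformers versions:

    import numpy as np
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("yujiepan/qwen2-audio-tiny-random")
    audio = np.random.random(20000).astype(np.float32)

    # sampling_rate is passed explicitly so the feature extractor does not have
    # to guess the rate of the raw waveform.
    inputs = processor(
        text="<|audio_bos|><|AUDIO|><|audio_eos|>Describe this audio.",
        audios=[audio],
        padding=True,
        return_tensors="pt",
        sampling_rate=processor.feature_extractor.sampling_rate,
    )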
7 changes: 4 additions & 3 deletions tests/generate/test_api.py
@@ -42,12 +42,13 @@ def test_vision_sequence_generator_validate_types(prompts, media, type_error):
     "prompts,media,type_error",
     [
         ("single prompt", [AUDIO_ARRAY], False),
-        (["prompt0", "prompt1"], [[AUDIO_ARRAY], [AUDIO_ARRAY]], False),
+        (["single prompt"], [AUDIO_ARRAY], False),
+        (["prompt0", "prompt1"], [AUDIO_ARRAY, AUDIO_ARRAY], False),
         ("single prompt", [AUDIO_ARRAY, AUDIO_ARRAY], False),
-        (["prompt0", "prompt1"], [[AUDIO_ARRAY, AUDIO_ARRAY], [AUDIO_ARRAY]], False),
         ("single prompt", "this isn't an audio, it's a string", True),
         ("single prompt", AUDIO_ARRAY, True),
-        (["prompt0", "prompt1"], [AUDIO_ARRAY], True),
+        (["prompt0", "prompt1"], [AUDIO_ARRAY], False),
+        ("prompt0", [[AUDIO_ARRAY]], True),
         (["prompt0", "prompt1"], [[AUDIO_ARRAY]], True),
         (["prompt0", "prompt1"], [[[AUDIO_ARRAY]], [[AUDIO_ARRAY]]], True),
     ],
27 changes: 20 additions & 7 deletions tests/generate/test_generate.py
@@ -77,7 +77,7 @@ def model_transformers_audio(tmp_path_factory):
     from transformers import Qwen2AudioForConditionalGeneration

     return models.transformers_audio(
-        "Qwen/Qwen2-Audio-7B-Instruct",
+        "yujiepan/qwen2-audio-tiny-random",
         model_class=Qwen2AudioForConditionalGeneration,
         device="cpu",
     )
@@ -197,7 +197,12 @@ def enforce_not_implemented(model_fixture, *task_names):
             "model_transformers_audio",
         ],
         "batch": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
-        "beam_search": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
+        "beam_search": [
+            "model_llamacpp",
+            "model_mlxlm",
+            "model_mlxlm_phi3",
+            "model_transformers_audio",
+        ],
         "multiple_samples": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
         "cfg": ["model_llamacpp"],  # TODO: fix llama_cpp tokenizer
     }
@@ -232,14 +237,20 @@ def get_inputs(fixture_name, batch_size=None):

     elif fixture_name.endswith("_audio"):
         instruct_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
-        audio = np.random.random(20000)
+        batch_prompt = "<|im_start|>assistant\n"
+        audio = np.random.random(20000).astype(np.float32)

         if batch_size is None:
-            return {"prompts": f"{instruct_prompt}{prompts}", "media": [audio]}
+            return {
+                "prompts": f"{instruct_prompt}{prompts}<|im_end|>\n",
+                "media": [audio],
+            }
         else:
             return {
-                "prompts": [f"{instruct_prompt}{p}" for p in prompts],
-                "media": [[audio] for _ in range(batch_size)],
+                "prompts": [
+                    f"{instruct_prompt}{p}<|im_end|>\n{batch_prompt}" for p in prompts
+                ],
+                "media": [audio for _ in range(batch_size)],
             }

     else:
@@ -420,7 +431,9 @@ def test_generate_regex_batch_multi_sample(
     generator = generate.regex(
         model, pattern, sampler=getattr(samplers, sampler_name)(4)
     )
-    with enforce_not_implemented(model_fixture, "batch", "multiple_samples"):
+    with enforce_not_implemented(
+        model_fixture, "batch", "multiple_samples", sampler_name
+    ):
         output_batch_groups = generator(**get_inputs(model_fixture, 4), max_tokens=40)
         for output_sample_groups in output_batch_groups:
             for output in output_sample_groups:
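Putting the changed pieces together, a hypothetical end-to-end run on CPU; the prompt template and waveform shape are taken from get_inputs above, everything else is an assumption about how the outlines text generator is usually invoked:

    import numpy as np
    from outlines import generate, models
    from transformers import Qwen2AudioForConditionalGeneration

    # Tiny random checkpoint so the example stays cheap, as in the updated fixtures.
    model = models.transformers_audio(
        "yujiepan/qwen2-audio-tiny-random",
        model_class=Qwen2AudioForConditionalGeneration,
        device="cpu",
    )
    audio = np.random.random(20000).astype(np.float32)
    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        "Describe this audio.<|im_end|>\n<|im_start|>assistant\n"
    )
    generator = generate.text(model)
    answer = generator(prompts=prompt, media=[audio], max_tokens=10)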
6 changes: 3 additions & 3 deletions tests/generate/test_integration_transformers_audio.py
@@ -25,15 +25,15 @@ def audio_from_url(url):
 @pytest.fixture(scope="session")
 def model(tmp_path_factory):
     return transformers_audio(
-        "Qwen/Qwen2-Audio-7B-Instruct",
+        "yujiepan/qwen2-audio-tiny-random",
         model_class=Qwen2AudioForConditionalGeneration,
         device="cpu",
     )


 @pytest.fixture(scope="session")
 def processor(tmp_path_factory):
-    return AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
+    return AutoProcessor.from_pretrained("yujiepan/qwen2-audio-tiny-random")


 def test_single_audio_text_gen(model, processor):
@@ -130,7 +130,7 @@ def test_single_audio_choice(model, processor):
             "role": "user",
             "content": [
                 {"audio"},
-                {"type": "text", "text": "What is this?"},
+                {"type": "text", "text": "What's that sound?"},
             ],
         },
     ]
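For reference, a hedged sketch of how a conversation like the one above is typically rendered into a prompt in these tests; it assumes the message list is bound to a name such as conversation and that the processor for this checkpoint exposes apply_chat_template, as the upstream Qwen2-Audio processor does:

    # Render the chat-format messages into a single prompt string; tokenization
    # is left to the outlines generator.
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )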
