From 210ac07f3dc1262eda3d0981e889b609bdb9aa17 Mon Sep 17 00:00:00 2001
From: Gaël
Date: Wed, 4 Dec 2024 17:13:25 +0000
Subject: [PATCH] fix(test): use tiny model for audio transformers

---
 outlines/generate/api.py                        | 19 +++++-------
 outlines/models/transformers_audio.py           |  6 ++++-
 tests/generate/test_api.py                      |  7 ++---
 tests/generate/test_generate.py                 | 27 ++++++++++++++-----
 .../test_integration_transformers_audio.py      |  6 ++---
 5 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index a248162b1..396166622 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -691,23 +691,18 @@ def _validate_prompt_media_types(
     def valid_types(prompts, media):
         import numpy as np  # type: ignore
 
-        if isinstance(prompts, list):
-            if not isinstance(media, list) or len(prompts) != len(media):
-                return False
-            for subprompt, submedia in zip(prompts, media):
-                if not isinstance(subprompt, str) or not all(
-                    isinstance(m, np.ndarray) for m in submedia
-                ):
-                    return False
-        elif isinstance(prompts, str):
-            if not all(isinstance(m, np.ndarray) for m in media):
-                return False
+        if not isinstance(prompts, (str, list)):
+            return False
+        if not isinstance(media, list):
+            return False
+        if not all(isinstance(m, np.ndarray) for m in media):
+            return False
         return True
 
     if not valid_types(prompts, media):
         raise TypeError(
             "Expected (prompts, media) to be of type "
-            "(str, List[np.ndarray])), or (List[str], List[List[np.ndarray]]) "
+            "(str, List[np.ndarray]), or (List[str], List[np.ndarray]) "
             f"instead got prompts={prompts}, media={media}"
         )
 
diff --git a/outlines/models/transformers_audio.py b/outlines/models/transformers_audio.py
index bcfa8d848..39b55b8d1 100644
--- a/outlines/models/transformers_audio.py
+++ b/outlines/models/transformers_audio.py
@@ -44,7 +44,11 @@ def generate(  # type: ignore
             The generated text
         """
         inputs = self.processor(
-            text=prompts, audios=media, padding=True, return_tensors="pt"
+            text=prompts,
+            audios=media,
+            padding=True,
+            return_tensors="pt",
+            sampling_rate=self.processor.feature_extractor.sampling_rate,
         ).to(self.model.device)
 
         generation_kwargs = self._get_generation_kwargs(
diff --git a/tests/generate/test_api.py b/tests/generate/test_api.py
index 69d39cd47..881da04ed 100644
--- a/tests/generate/test_api.py
+++ b/tests/generate/test_api.py
@@ -42,12 +42,13 @@ def test_vision_sequence_generator_validate_types(prompts, media, type_error):
     "prompts,media,type_error",
     [
         ("single prompt", [AUDIO_ARRAY], False),
-        (["prompt0", "prompt1"], [[AUDIO_ARRAY], [AUDIO_ARRAY]], False),
+        (["single prompt"], [AUDIO_ARRAY], False),
+        (["prompt0", "prompt1"], [AUDIO_ARRAY, AUDIO_ARRAY], False),
         ("single prompt", [AUDIO_ARRAY, AUDIO_ARRAY], False),
-        (["prompt0", "prompt1"], [[AUDIO_ARRAY, AUDIO_ARRAY], [AUDIO_ARRAY]], False),
         ("single prompt", "this isn't an audio, it's a string", True),
         ("single prompt", AUDIO_ARRAY, True),
-        (["prompt0", "prompt1"], [AUDIO_ARRAY], True),
+        (["prompt0", "prompt1"], [AUDIO_ARRAY], False),
+        ("prompt0", [[AUDIO_ARRAY]], True),
         (["prompt0", "prompt1"], [[AUDIO_ARRAY]], True),
         (["prompt0", "prompt1"], [[[AUDIO_ARRAY]], [[AUDIO_ARRAY]]], True),
     ],
diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py
index 348c89511..6b48226ba 100644
--- a/tests/generate/test_generate.py
+++ b/tests/generate/test_generate.py
@@ -77,7 +77,7 @@ def model_transformers_audio(tmp_path_factory):
     from transformers import Qwen2AudioForConditionalGeneration
 
     return models.transformers_audio(
-        "Qwen/Qwen2-Audio-7B-Instruct",
+        "yujiepan/qwen2-audio-tiny-random",
         model_class=Qwen2AudioForConditionalGeneration,
         device="cpu",
     )
@@ -197,7 +197,12 @@ def enforce_not_implemented(model_fixture, *task_names):
         "model_transformers_audio",
     ],
     "batch": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
-    "beam_search": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
+    "beam_search": [
+        "model_llamacpp",
+        "model_mlxlm",
+        "model_mlxlm_phi3",
+        "model_transformers_audio",
+    ],
     "multiple_samples": ["model_llamacpp", "model_mlxlm", "model_mlxlm_phi3"],
     "cfg": ["model_llamacpp"],  # TODO: fix llama_cpp tokenizer
 }
@@ -232,14 +237,20 @@ def get_inputs(fixture_name, batch_size=None):
 
     elif fixture_name.endswith("_audio"):
         instruct_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
 
-        audio = np.random.random(20000)
+        batch_prompt = "<|im_start|>assistant\n"
+        audio = np.random.random(20000).astype(np.float32)
         if batch_size is None:
-            return {"prompts": f"{instruct_prompt}{prompts}", "media": [audio]}
+            return {
+                "prompts": f"{instruct_prompt}{prompts}<|im_end|>\n",
+                "media": [audio],
+            }
         else:
             return {
-                "prompts": [f"{instruct_prompt}{p}" for p in prompts],
-                "media": [[audio] for _ in range(batch_size)],
+                "prompts": [
+                    f"{instruct_prompt}{p}<|im_end|>\n{batch_prompt}" for p in prompts
+                ],
+                "media": [audio for _ in range(batch_size)],
             }
 
     else:
@@ -420,7 +431,9 @@ def test_generate_regex_batch_multi_sample(
     generator = generate.regex(
         model, pattern, sampler=getattr(samplers, sampler_name)(4)
     )
-    with enforce_not_implemented(model_fixture, "batch", "multiple_samples"):
+    with enforce_not_implemented(
+        model_fixture, "batch", "multiple_samples", sampler_name
+    ):
         output_batch_groups = generator(**get_inputs(model_fixture, 4), max_tokens=40)
         for output_sample_groups in output_batch_groups:
             for output in output_sample_groups:
diff --git a/tests/generate/test_integration_transformers_audio.py b/tests/generate/test_integration_transformers_audio.py
index 50d022089..d9fe0921c 100644
--- a/tests/generate/test_integration_transformers_audio.py
+++ b/tests/generate/test_integration_transformers_audio.py
@@ -25,7 +25,7 @@ def audio_from_url(url):
 @pytest.fixture(scope="session")
 def model(tmp_path_factory):
     return transformers_audio(
-        "Qwen/Qwen2-Audio-7B-Instruct",
+        "yujiepan/qwen2-audio-tiny-random",
         model_class=Qwen2AudioForConditionalGeneration,
         device="cpu",
     )
@@ -33,7 +33,7 @@ def model(tmp_path_factory):
 
 @pytest.fixture(scope="session")
 def processor(tmp_path_factory):
-    return AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
+    return AutoProcessor.from_pretrained("yujiepan/qwen2-audio-tiny-random")
 
 
 def test_single_audio_text_gen(model, processor):
@@ -130,7 +130,7 @@ def test_single_audio_choice(model, processor):
             "role": "user",
             "content": [
                 {"audio"},
-                {"type": "text", "text": "What is this?"},
+                {"type": "text", "text": "What's that sound?"},
             ],
         },
     ]
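
Reviewer note (not part of the patch): the sketch below illustrates the calling
convention this change settles on. `media` is always a flat `List[np.ndarray]`,
paired with `prompts` whether `prompts` is a single string or a list; the old
nested `List[List[np.ndarray]]` form now fails validation. This is a minimal
sketch assuming the tiny random-weight checkpoint and the Qwen2-Audio chat
template used in the tests above; the user text and regex are illustrative.

    import numpy as np
    from transformers import Qwen2AudioForConditionalGeneration

    from outlines import generate, models

    # Tiny random-weight checkpoint, as used in the test fixtures above.
    model = models.transformers_audio(
        "yujiepan/qwen2-audio-tiny-random",
        model_class=Qwen2AudioForConditionalGeneration,
        device="cpu",
    )

    # Synthetic audio clip; float32 matches what the processor expects.
    audio = np.random.random(20000).astype(np.float32)

    # Qwen2-Audio chat template, mirroring get_inputs() in test_generate.py.
    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        "Describe the audio.<|im_end|>\n<|im_start|>assistant\n"
    )

    generator = generate.regex(model, r"[a-zA-Z ]+")

    # Single prompt: media is a flat list, one array per <|AUDIO|> tag.
    out = generator(prompt, media=[audio], max_tokens=10)

    # Batched prompts: media stays a flat list, one array per prompt
    # (before this patch, the batched case required [[audio], [audio]]).
    outs = generator([prompt, prompt], media=[audio, audio], max_tokens=10)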