From b529821a4c8d26970653383feefb580453de31cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABl?= Date: Wed, 11 Dec 2024 11:47:15 +0000 Subject: [PATCH] test(audio): improve coverage of validate prompt and media --- outlines/generate/api.py | 7 +++++-- tests/generate/test_api.py | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 396166622..2c8a0d831 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -691,8 +691,11 @@ def _validate_prompt_media_types( def valid_types(prompts, media): import numpy as np # type: ignore - if not isinstance(prompts, (str, list)): - return False + if not isinstance(prompts, str): + if not isinstance(prompts, list): + return False + if not all(isinstance(p, str) for p in prompts): + return False if not isinstance(media, list): return False if not all(isinstance(m, np.ndarray) for m in media): diff --git a/tests/generate/test_api.py b/tests/generate/test_api.py index 881da04ed..4b162a147 100644 --- a/tests/generate/test_api.py +++ b/tests/generate/test_api.py @@ -42,6 +42,8 @@ def test_vision_sequence_generator_validate_types(prompts, media, type_error): "prompts,media,type_error", [ ("single prompt", [AUDIO_ARRAY], False), + (0, [AUDIO_ARRAY], True), + ([AUDIO_ARRAY], "single prompt", True), (["single prompt"], [AUDIO_ARRAY], False), (["prompt0", "prompt1"], [AUDIO_ARRAY, AUDIO_ARRAY], False), ("single prompt", [AUDIO_ARRAY, AUDIO_ARRAY], False),