From a7d6962a58e1d06b3646db76f7ec7b7dc8713ae5 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 18 Dec 2024 23:00:38 -0500 Subject: [PATCH] [CI] Expand test_guided_generate to test all backends (#11313) Signed-off-by: mgoin --- tests/entrypoints/llm/test_guided_generate.py | 112 +++++++++++------- .../model_executor/test_guided_processors.py | 4 +- .../guided_decoding/__init__.py | 64 +++++++++- 3 files changed, 129 insertions(+), 51 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index ed50ec6bbc9eb..e9c48f2b6b551 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -10,7 +10,8 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import GuidedDecodingParams, SamplingParams -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct" +GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] @pytest.fixture(scope="module") @@ -26,11 +27,13 @@ def llm(): @pytest.mark.skip_global_cleanup -def test_guided_regex(sample_regex, llm): - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams(regex=sample_regex)) +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_regex(sample_regex, llm, guided_decoding_backend: str): + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + regex=sample_regex, + backend=guided_decoding_backend)) outputs = llm.generate(prompts=[ f"Give an example IPv4 address with this regex: {sample_regex}" ] * 2, @@ -50,11 +53,14 @@ def test_guided_regex(sample_regex, llm): @pytest.mark.skip_global_cleanup -def test_guided_json_completion(sample_json_schema, llm): - sampling_params = SamplingParams( - temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams(json=sample_json_schema)) +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_json_completion(sample_json_schema, llm, + guided_decoding_backend: str): + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_json_schema, + backend=guided_decoding_backend)) outputs = llm.generate(prompts=[ f"Give an example JSON for an employee profile " f"that fits this schema: {sample_json_schema}" @@ -77,11 +83,14 @@ def test_guided_json_completion(sample_json_schema, llm): @pytest.mark.skip_global_cleanup -def test_guided_complex_json_completion(sample_complex_json_schema, llm): - sampling_params = SamplingParams( - temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams(json=sample_complex_json_schema)) +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_complex_json_completion(sample_complex_json_schema, llm, + guided_decoding_backend: str): + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_complex_json_schema, + backend=guided_decoding_backend)) outputs = llm.generate(prompts=[ f"Give an example JSON for an assignment grade " f"that fits this schema: {sample_complex_json_schema}" @@ -105,11 +114,14 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm): @pytest.mark.skip_global_cleanup -def test_guided_definition_json_completion(sample_definition_json_schema, llm): +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_definition_json_completion(sample_definition_json_schema, llm, + guided_decoding_backend: str): sampling_params = SamplingParams(temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams( - json=sample_definition_json_schema)) + json=sample_definition_json_schema, + backend=guided_decoding_backend)) outputs = llm.generate(prompts=[ f"Give an example JSON for solving 8x + 7 = -23 " f"that fits this schema: {sample_definition_json_schema}" @@ -133,11 +145,14 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm): @pytest.mark.skip_global_cleanup -def test_guided_choice_completion(sample_guided_choice, llm): - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_choice_completion(sample_guided_choice, llm, + guided_decoding_backend: str): + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + choice=sample_guided_choice, + backend=guided_decoding_backend)) outputs = llm.generate( prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, @@ -156,13 +171,20 @@ def test_guided_choice_completion(sample_guided_choice, llm): @pytest.mark.skip_global_cleanup -def test_guided_grammar(sample_sql_statements, llm): - - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements)) +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_grammar(sample_sql_statements, llm, + guided_decoding_backend: str): + if guided_decoding_backend == "outlines": + pytest.skip("Outlines backend fails in this test case with:\n" + "AttributeError: Error in model execution: 'ParserConf' " + "object has no attribute 'deterministic'") + + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar=sample_sql_statements, + backend=guided_decoding_backend)) outputs = llm.generate( prompts=("Generate a sql state that select col_1 from " "table_1 where it is equals to 1"), @@ -218,15 +240,18 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm): @pytest.mark.skip_global_cleanup -def test_guided_json_object(llm): - sampling_params = SamplingParams( - temperature=1.0, - max_tokens=100, - guided_decoding=GuidedDecodingParams(json_object=True)) +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_json_object(llm, guided_decoding_backend: str): + sampling_params = SamplingParams(temperature=1.0, + max_tokens=100, + n=2, + guided_decoding=GuidedDecodingParams( + json_object=True, + backend=guided_decoding_backend)) outputs = llm.generate( - prompts=("Generate a JSON object describing a person with name " - "and age for John Smith who is 31 years old."), + prompts=("Generate a JSON object with curly braces for a person with " + "name and age fields for John Smith who is 31 years old."), sampling_params=sampling_params, use_tqdm=True) @@ -235,10 +260,11 @@ def test_guided_json_object(llm): assert output is not None assert isinstance(output, RequestOutput) - generated_text = output.outputs[0].text - print(generated_text) - assert generated_text is not None + for i in range(2): + generated_text = output.outputs[i].text + print(generated_text) + assert generated_text is not None - # Parse to verify it is valid JSON - parsed_json = json.loads(generated_text) - assert isinstance(parsed_json, dict) + # Parse to verify it is valid JSON + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 3334c0df149b5..be5282d9c8223 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -13,6 +13,7 @@ from vllm.sampling_params import GuidedDecodingParams MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' +GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] def test_guided_logits_processors(sample_regex, sample_json_schema): @@ -42,8 +43,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio -@pytest.mark.parametrize("backend", - ["outlines", "lm-format-enforcer", "xgrammar"]) +@pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS) @pytest.mark.parametrize("is_local", [True, False]) async def test_guided_logits_processor_black_box(backend: str, is_local: bool, sample_regex, diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index e631aec928ec5..550b892303feb 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -49,15 +49,60 @@ def check_object(obj: dict) -> bool: return check_object(schema) +def has_lmf_unsupported_json_features(schema: dict) -> bool: + """ + Check if JSON schema contains features unsupported + by lm_format_enforcer. + + Known issues: + - Regex patterns: + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + """ + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + def maybe_backend_fallback( guided_params: GuidedDecodingParams) -> GuidedDecodingParams: # lm-format-enforce doesn't support grammar, fallback to xgrammar - if (guided_params.backend == "lm-format-enforcer" - and guided_params.grammar is not None): - logger.warning( - "lm-format-enforcer does not support grammar guided decoding. " - "Falling back to use xgrammar instead.") - guided_params.backend = "xgrammar" + if guided_params.backend == "lm-format-enforcer": + if guided_params.grammar is not None: + logger.warning( + "lm-format-enforcer does not support grammar guided decoding. " + "Falling back to use xgrammar instead.") + guided_params.backend = "xgrammar" + + # lm-format-enforcer doesn't support some JSON schema features + elif (guided_params.json is not None + and has_lmf_unsupported_json_features(guided_params.json)): + logger.warning( + "lm-format-enforcer does not support advanced JSON schema " + "features like patterns or numeric ranges. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" if guided_params.backend == "xgrammar": # xgrammar only has x86 wheels for linux, fallback to outlines @@ -82,6 +127,13 @@ def maybe_backend_fallback( "Falling back to use outlines instead.") guided_params.backend = "outlines" + if (guided_params.backend == "outlines" + and guided_params.json_object is not None): + # outlines doesn't support json_object, fallback to xgrammar + logger.warning("outlines does not support json_object. " + "Falling back to use xgrammar instead.") + guided_params.backend = "xgrammar" + return guided_params