From 445efbb81addced22e4e315fea6320b8fbd5b057 Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Sun, 24 Nov 2024 18:17:33 +0800 Subject: [PATCH] [Bug]: Authorization ignored when root_path is set FIX: #10531 Signed-off-by: chaunceyjiang --- tests/entrypoints/openai/test_root_path.py | 103 +++++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 6 +- 2 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 tests/entrypoints/openai/test_root_path.py diff --git a/tests/entrypoints/openai/test_root_path.py b/tests/entrypoints/openai/test_root_path.py new file mode 100644 index 0000000000000..20f7960619efb --- /dev/null +++ b/tests/entrypoints/openai/test_root_path.py @@ -0,0 +1,103 @@ +import contextlib +import os +from typing import Any, List, NamedTuple + +import openai # use the official client for correctness check +import pytest + +from ...utils import RemoteOpenAIServer + +# # any model with a chat template should work here +MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 +API_KEY = "abc-123" +ERROR_API_KEY = "abc" +ROOT_PATH = "llm" + + +@pytest.fixture(scope="module") +def server(): + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--enforce-eager", + "--max-model-len", + "4080", + "--root-path", # use --root-path=/llm for testing + "/" + ROOT_PATH, + "--chat-template", + DUMMY_CHAT_TEMPLATE, + ] + envs = os.environ.copy() + + envs["VLLM_API_KEY"] = API_KEY + with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server: + yield remote_server + + +class TestCase(NamedTuple): + model_name: str + base_url: List[str] + api_key: str + expected_error: Any + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + model_name=MODEL_NAME, + base_url=["v1"], # http://localhost:8000/v1 + api_key=ERROR_API_KEY, + expected_error=openai.AuthenticationError), + TestCase( + model_name=MODEL_NAME, + base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 + api_key=ERROR_API_KEY, + expected_error=openai.AuthenticationError), + TestCase( + model_name=MODEL_NAME, + base_url=["v1"], # http://localhost:8000/v1 + api_key=API_KEY, + expected_error=None), + TestCase( + model_name=MODEL_NAME, + base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 + api_key=API_KEY, + expected_error=None), + ], +) +async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, + test_case: TestCase): + saying: str = "Here is a common saying about apple. An apple a day, keeps" + ctx = contextlib.nullcontext() + if test_case.expected_error is not None: + ctx = pytest.raises(test_case.expected_error) + with ctx: + client = openai.AsyncOpenAI( + api_key=test_case.api_key, + base_url=server.url_for(*test_case.base_url), + max_retries=0) + chat_completion = await client.chat.completions.create( + model=test_case.model_name, + messages=[{ + "role": "user", + "content": "tell me a common saying" + }, { + "role": "assistant", + "content": saying + }], + extra_body={ + "continue_final_message": True, + "add_generation_prompt": False + }) + + assert chat_completion.id is not None + assert len(chat_completion.choices) == 1 + choice = chat_completion.choices[0] + assert choice.finish_reason == "stop" + message = choice.message + assert len(message.content) > 0 + assert message.role == "assistant" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b0fe061f5db4a..7fe3696a1cf2a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -475,10 +475,12 @@ async def validation_exception_handler(_, exc): @app.middleware("http") async def authentication(request: Request, call_next): - root_path = "" if args.root_path is None else args.root_path if request.method == "OPTIONS": return await call_next(request) - if not request.url.path.startswith(f"{root_path}/v1"): + url_path = request.url.path + if app.root_path and url_path.startswith(app.root_path): + url_path = url_path[len(app.root_path):] + if not url_path.startswith("/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: return JSONResponse(content={"error": "Unauthorized"},