From b6610fb121ff75772ee9edc61cd30e22c057d851 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 23 Jan 2025 15:26:21 -0600 Subject: [PATCH 01/23] feat: serving_rerank implementation Signed-off-by: Kyle Mistele --- vllm/entrypoints/openai/serving_rerank.py | 201 ++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 vllm/entrypoints/openai/serving_rerank.py diff --git a/vllm/entrypoints/openai/serving_rerank.py b/vllm/entrypoints/openai/serving_rerank.py new file mode 100644 index 0000000000000..485a0ec30a3e7 --- /dev/null +++ b/vllm/entrypoints/openai/serving_rerank.py @@ -0,0 +1,201 @@ +import asyncio +import time +from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, + RerankRequest, RerankResponse, + RerankResult, RerankUsage) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import TokensPrompt +from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.utils import make_async, merge_async_iterators + +logger = init_logger(__name__) + + +class JinaAIServingRerank(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + async def do_rerank( + self, + request: RerankRequest, + raw_request: Optional[Request] = None + ) -> Union[RerankResponse, ErrorResponse]: + """ + Rerank API based on JinaAI's rerank API; implements the same + API interface. Designed for compatibility with off-the-shelf + tooling, since this is a common standard for reranking APIs + + See example client implementations at + https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py + numerous clients use this standard. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"rerank-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + truncate_prompt_tokens = request.truncate_prompt_tokens + query = request.query + documents = request.documents + request_prompts = [] + engine_prompts = [] + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for scoring models") + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + if not self.model_config.is_cross_encoder: + raise ValueError("Model is not cross encoder.") + + if truncate_prompt_tokens is not None and \ + truncate_prompt_tokens > self.max_model_len: + raise ValueError( + f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " + f"is greater than max_model_len ({self.max_model_len})." 
+ f" Please, select a smaller truncation size.") + for doc in documents: + request_prompt = f"{query}{tokenizer.sep_token}{doc}" + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + prompt_inputs = await tokenize_async(text=query, + text_pair=doc, + **tokenization_kwargs) + + input_ids = prompt_inputs["input_ids"] + text_token_prompt = \ + self._validate_input(request, input_ids, request_prompt) + engine_prompt = TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + result_generator = merge_async_iterators(*generators) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: List[Optional[PoolingRequestOutput]] + final_res_batch = [None] * num_prompts + + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(List[PoolingRequestOutput], + final_res_batch) + + response = self.request_output_to_rerank_response( + final_res_batch_checked, request_id, created_time, model_name, + documents) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def request_output_to_rerank_response( + self, final_res_batch: List[PoolingRequestOutput], request_id: str, + model_name: str, documents: List[str]) -> RerankResponse: + """ + Convert the output of do_rank to a RerankResponse + """ + results: List[RerankResult] = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + classify_res = ScoringRequestOutput.from_base(final_res) + + result = RerankResult( + index=idx, + document=RerankDocument(text=documents[idx]), + relevance_score=classify_res.outputs.score, + ) + results.append(result) + prompt_token_ids = final_res.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + + return RerankResponse( + id=request_id, + model=model_name, + results=results, + usage=RerankUsage(total_tokens=num_prompt_tokens)) From a82b4bbb5e6935d83de99dd4e1c4cd26397abb88 Mon Sep 17 00:00:00 2001 
From: Kyle Mistele Date: Thu, 23 Jan 2025 17:00:25 -0600 Subject: [PATCH 02/23] fix: imports Signed-off-by: Kyle Mistele --- vllm/entrypoints/openai/api_server.py | 41 ++++++++++++++++ vllm/entrypoints/openai/protocol.py | 58 +++++++++++++++++++++++ vllm/entrypoints/openai/serving_engine.py | 9 ++-- vllm/entrypoints/openai/serving_rerank.py | 15 ++++-- 4 files changed, 114 insertions(+), 9 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f510c41503011..7102976b625bc 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -56,6 +56,7 @@ PoolingChatRequest, PoolingCompletionRequest, PoolingRequest, PoolingResponse, + RerankRequest, RerankResponse, ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, @@ -68,6 +69,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling +from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) @@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]: return request.app.state.openai_serving_scores +def rerank(request: Request) -> Optional[JinaAIServingRerank]: + return request.app.state.jinaai_serving_reranking + + def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization @@ -502,6 +508,33 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) +@router.post("/rerank") +@with_cancellation +async def do_rerank(request: RerankRequest, raw_request: Request): + handler = rerank(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Rerank (Score) API") + generator = await handler.do_rerank(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, RerankResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/rerank") +@with_cancellation +async def do_rerank_v1(request: RerankRequest, raw_request: Request): + logger.warning( + "To indicate that the rerank API is not part of the standard OpenAI" + " API, we have located it at `/rerank`. Please update your client" + "accordingly. 
(Note: Conforms to JinaAI rerank API)") + return await do_rerank(request, raw_request) + + TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { "generate": { "messages": (ChatCompletionRequest, create_chat_completion), @@ -514,6 +547,9 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): "score": { "default": (ScoreRequest, create_score), }, + "rerank": { + "default": (RerankRequest, do_rerank) + }, "reward": { "messages": (PoolingChatRequest, create_pooling), "default": (PoolingCompletionRequest, create_pooling), @@ -759,6 +795,11 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger ) if model_config.task == "score" else None + state.jinaai_serving_reranking = JinaAIServingRerank( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 80403f77d5375..2487e6bf638c4 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1000,6 +1000,64 @@ def to_pooling_params(self): return PoolingParams(additional_data=self.additional_data) +class RerankRequest(OpenAIBaseModel): + model: str + query: str + documents: List[str] + top_n: int = Field(default_factory=lambda: 0) + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-rerank-pooling-params + additional_data: Optional[Any] = None + # doc: end-rerank-pooling-params + + # doc: begin-rerank-extra-params + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-rerank-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + @classmethod + def __get_validators__(cls): + yield cls.validate_top_n + + # validator to set the top_n value to the length of the documents if not set + @classmethod + def validate_top_n(cls, values): + # the lambda sets the field to zero if it's not set + if values.get('top_n') == 0: + values['top_n'] = len(values.get('documents', [])) + return values + + +class RerankDocument(BaseModel): + text: str + + +class RerankResult(BaseModel): + index: int + document: RerankDocument + relevance_score: float + + +class RerankUsage(BaseModel): + total_tokens: int + + +class RerankResponse(OpenAIBaseModel): + id: str + model: str + usage: RerankUsage + results: List[RerankResult] + + class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 3da447be06430..8d54164e500eb 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -26,7 +26,8 @@ DetokenizeRequest, EmbeddingChatRequest, EmbeddingCompletionRequest, - ErrorResponse, ScoreRequest, + ErrorResponse, RerankRequest, + ScoreRequest, TokenizeChatRequest, TokenizeCompletionRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels @@ -204,9 +205,9 @@ def _validate_input( token_num = len(input_ids) # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens - if isinstance( - request, - (EmbeddingChatRequest, 
EmbeddingCompletionRequest, ScoreRequest)): + if isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest, + ScoreRequest, RerankRequest)): operation = "score" if isinstance(request, ScoreRequest) \ else "embedding generation" diff --git a/vllm/entrypoints/openai/serving_rerank.py b/vllm/entrypoints/openai/serving_rerank.py index 485a0ec30a3e7..eff06725250d0 100644 --- a/vllm/entrypoints/openai/serving_rerank.py +++ b/vllm/entrypoints/openai/serving_rerank.py @@ -1,5 +1,4 @@ import asyncio -import time from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast from fastapi import Request @@ -56,12 +55,12 @@ async def do_rerank( model_name = request.model request_id = f"rerank-{self._base_request_id(raw_request)}" - created_time = int(time.time()) truncate_prompt_tokens = request.truncate_prompt_tokens query = request.query documents = request.documents request_prompts = [] engine_prompts = [] + top_n = request.top_n try: ( @@ -164,8 +163,8 @@ async def do_rerank( final_res_batch) response = self.request_output_to_rerank_response( - final_res_batch_checked, request_id, created_time, model_name, - documents) + final_res_batch_checked, request_id, model_name, documents, + top_n) except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: @@ -176,7 +175,8 @@ async def do_rerank( def request_output_to_rerank_response( self, final_res_batch: List[PoolingRequestOutput], request_id: str, - model_name: str, documents: List[str]) -> RerankResponse: + model_name: str, documents: List[str], + top_n: int) -> RerankResponse: """ Convert the output of do_rank to a RerankResponse """ @@ -194,6 +194,11 @@ def request_output_to_rerank_response( prompt_token_ids = final_res.prompt_token_ids num_prompt_tokens += len(prompt_token_ids) + # sort by relevance, then return the top n if set + results.sort(key=lambda x: x.relevance_score, reverse=True) + if top_n < len(documents): + results = results[:top_n] + return RerankResponse( id=request_id, model=model_name, From 99acff6adc75b2931742a4a70225b130b13b73dd Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 23 Jan 2025 17:32:01 -0600 Subject: [PATCH 03/23] doc: add example requests and scripts Signed-off-by: Kyle Mistele --- .../serving/openai_compatible_server.md | 88 +++++++++++++++++++ .../online_serving/jinjaai_rerank_client.py | 28 ++++++ vllm/entrypoints/openai/protocol.py | 12 --- vllm/entrypoints/openai/serving_rerank.py | 2 +- 4 files changed, 117 insertions(+), 13 deletions(-) create mode 100644 examples/online_serving/jinjaai_rerank_client.py diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index e49bbb06695f8..1a0c17321d6eb 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -50,6 +50,9 @@ In addition, we have the following custom APIs: - Applicable to all [pooling models](../models/pooling_models.md). - [Score API](#score-api) (`/score`) - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`). +- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`) + - Implements [Jina AI's rerank API](https://jina.ai/reranker/) which is a common standard for re-rank APIs + - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`). 
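One note on the `top_n` plumbing just added: the `__get_validators__` hook on `RerankRequest` is a Pydantic v1 mechanism that a Pydantic v2 `BaseModel` does not call, and a later patch in this series removes it, defaulting `top_n` in `do_rerank` instead (`top_n = request.top_n if request.top_n > 0 else len(documents)`); the response builder then sorts by score and slices. A self-contained sketch of that post-processing, with made-up `(index, score)` pairs standing in for real pooling outputs:

```python
from typing import List, Tuple


def select_top_n(scored: List[Tuple[int, float]],
                 top_n: int) -> List[Tuple[int, float]]:
    # Mirrors request_output_to_rerank_response: sort by relevance score,
    # descending, then keep only the first top_n entries when top_n is
    # smaller than the number of documents.
    ordered = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return ordered[:top_n] if top_n < len(ordered) else ordered


# Three documents scored against one query; keep the best two.
print(select_top_n([(0, 0.0006), (1, 0.9985), (2, 0.12)], top_n=2))
# -> [(1, 0.9985), (2, 0.12)]
```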
(chat-template)= @@ -473,3 +476,88 @@ The following extra parameters are supported: :start-after: begin-score-extra-params :end-before: end-score-extra-params ``` + +(rerank-api) = + +### Re-rank API + +Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and +each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on +a scale of 0 to 1. + +You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). + +Compatible with popular re-rank models such as `BAAI/bge-reranker-base`, the `/rerank` and `/v1/rerank` +endpoints implement [Jina AI's re-rank API interface](https://jina.ai/reranker/) to ensure compatibility with +popular open-source tools. + +Code example: + +#### Example Request + +Note that the `top_n` request parameter is optional and will default to the length of the `documents` field. +Result documents will be sorted by relevance, and the `index` property can be used to determine original order. + +Request: + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/rerank' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-base", + "query": "What is the capital of France?", + "documents": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", + "Horses and cows are both animals" + ] +}' +``` + +Response: + +```bash +{ + "id": "rerank-fae51b2b664d4ed38f5969b612edff77", + "model": "BAAI/bge-reranker-base", + "usage": { + "total_tokens": 56 + }, + "results": [ + { + "index": 1, + "document": { + "text": "The capital of France is Paris." + }, + "relevance_score": 0.99853515625 + }, + { + "index": 0, + "document": { + "text": "The capital of Brazil is Brasilia." + }, + "relevance_score": 0.0005860328674316406 + } + ] +} +``` + +#### Extra parameters + +The following [pooling parameters](#pooling-params) are supported. 
+ +```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-rerank-pooling-params +:end-before: end-rerank-pooling-params +``` + +The following extra parameters are supported: + +```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-rerank-extra-params +:end-before: end-rerank-extra-params +``` diff --git a/examples/online_serving/jinjaai_rerank_client.py b/examples/online_serving/jinjaai_rerank_client.py new file mode 100644 index 0000000000000..617fa9e0a08fc --- /dev/null +++ b/examples/online_serving/jinjaai_rerank_client.py @@ -0,0 +1,28 @@ +import json + +import requests + +url = "http://127.0.0.1:8000/rerank" + +headers = {"accept": "application/json", "Content-Type": "application/json"} + +data = { + "model": + "BAAI/bge-reranker-base", + "query": + "What is the capital of France?", + "documents": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", "Horses and cows are both animals" + ] +} + +response = requests.post(url, headers=headers, json=data) + +# Check the response +if response.status_code == 200: + print("Request successful!") + print(json.dumps(response.json(), indent=2)) +else: + print(f"Request failed with status code: {response.status_code}") + print(response.text) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 2487e6bf638c4..c3cfa876f5788 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1024,18 +1024,6 @@ class RerankRequest(OpenAIBaseModel): def to_pooling_params(self): return PoolingParams(additional_data=self.additional_data) - @classmethod - def __get_validators__(cls): - yield cls.validate_top_n - - # validator to set the top_n value to the length of the documents if not set - @classmethod - def validate_top_n(cls, values): - # the lambda sets the field to zero if it's not set - if values.get('top_n') == 0: - values['top_n'] = len(values.get('documents', [])) - return values - class RerankDocument(BaseModel): text: str diff --git a/vllm/entrypoints/openai/serving_rerank.py b/vllm/entrypoints/openai/serving_rerank.py index eff06725250d0..be4420261afe3 100644 --- a/vllm/entrypoints/openai/serving_rerank.py +++ b/vllm/entrypoints/openai/serving_rerank.py @@ -60,7 +60,7 @@ async def do_rerank( documents = request.documents request_prompts = [] engine_prompts = [] - top_n = request.top_n + top_n = request.top_n if request.top_n > 0 else len(documents) try: ( From 31b5137e751aa95861a158215a9e08f5ac224976 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 23 Jan 2025 19:28:52 -0600 Subject: [PATCH 04/23] test: rerank also, add documentation and update client with instructions Signed-off-by: Kyle Mistele --- .../serving/openai_compatible_server.md | 10 +- .../online_serving/jinjaai_rerank_client.py | 5 + tests/entrypoints/openai/test_rerank.py | 98 +++++++++++++++++++ 3 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 tests/entrypoints/openai/test_rerank.py diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 1a0c17321d6eb..3f8edfb0e7105 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -51,7 +51,9 @@ In addition, we have the following custom APIs: - [Score API](#score-api) (`/score`) - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`). 
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`) - - Implements [Jina AI's rerank API](https://jina.ai/reranker/) which is a common standard for re-rank APIs + - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/) + - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank) + - Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response. - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`). (chat-template)= @@ -487,8 +489,10 @@ a scale of 0 to 1. You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). -Compatible with popular re-rank models such as `BAAI/bge-reranker-base`, the `/rerank` and `/v1/rerank` -endpoints implement [Jina AI's re-rank API interface](https://jina.ai/reranker/) to ensure compatibility with +The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the +`score` task. Additionally, both `/rerank` and `/v1/rerank` endpoints +endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and +[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with popular open-source tools. Code example: diff --git a/examples/online_serving/jinjaai_rerank_client.py b/examples/online_serving/jinjaai_rerank_client.py index 617fa9e0a08fc..84ff7461146eb 100644 --- a/examples/online_serving/jinjaai_rerank_client.py +++ b/examples/online_serving/jinjaai_rerank_client.py @@ -1,3 +1,8 @@ +""" +Example of using the OpenAI entrypoint's rerank API which is compatible with +Jina and Cohere +run: vllm serve --model BAAI/bge-reranker-base +""" import json import requests diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py new file mode 100644 index 0000000000000..0d6cde6c05ad9 --- /dev/null +++ b/tests/entrypoints/openai/test_rerank.py @@ -0,0 +1,98 @@ +import pytest +import requests + +from vllm.entrypoints.openai.protocol import RerankResponse + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "BAAI/bge-reranker-base" + + +@pytest.fixture(scope="module") +def server(): + args = ['--enforce-eager', '--max-model-len 100'] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): + query = "What is the capital of France?" + documents = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + rerank_response = requests.post(server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + }) + rerank_response.raise_for_status() + rerank = RerankResponse.model_validate(rerank_response.json()) + + assert rerank.id is not None + assert rerank.results is not None + assert len(rerank.results) == 2 + assert rerank.results[1].relevance_score <= 0.01 + assert rerank.results[0].relevance_score >= 0.9 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_top_n(server: RemoteOpenAIServer, model_name: str): + query = "What is the capital of France?" 
+ documents = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", "Cross-encoder models are neat" + ] + + rerank_response = requests.post(server.url_for("score"), json={ + "model": model_name, + "query": query, + "documents": documents, + "top_n": 2 }) + rerank_response.raise_for_status() + rerank = RerankResponse.model_validate(rerank_response.json()) + + assert rerank.id is not None + assert rerank.results is not None + assert len(rerank.results) == 2 + assert rerank.results[1].relevance_score <= 0.01 + assert rerank.results[0].relevance_score >= 0.9 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str): + + query = "What is the capital of France?" * 100 + documents = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + rerank_response = requests.post(server.url_for("rerank"), json={ + "model": model_name, + "query": query, + "documents": documents }) + assert rerank_response.status_code == 400 + # Assert just a small fragments of the response + assert "Please reduce the length of the input." in \ + rerank_response.text + + # Test truncation + rerank_response = requests.post(server.url_for("rerank"), json={ + "model": model_name, + "query": query, + "documents": documents, + "truncate_prompt_tokens": 101 }) + assert rerank_response.status_code == 400 + assert "Please, select a smaller truncation size." in \ + rerank_response.text
From 676eea057d6bbf3e00d1b020bca75fe9aa1732a3 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 23 Jan 2025 22:35:29 -0600 Subject: [PATCH 09/23] added /v2/rerank route Signed-off-by: Kyle Mistele --- vllm/entrypoints/openai/api_server.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7102976b625bc..41e14af240b38 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -535,6 +535,16 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) +@router.post("/v2/rerank") +@with_cancellation +async def do_rerank_v2(request: RerankRequest, raw_request: Request): + logger.warning( + "To indicate that the rerank API is not part of the standard OpenAI" + " API, we have located it at `/rerank`. Please update your client" + "accordingly. (Note: Conforms to JinaAI rerank API)") + return await do_rerank(request, raw_request) + + TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { "generate": { "messages": (ChatCompletionRequest, create_chat_completion), From b66bcc22dc9faca82dbac6026a8927ea521bb136 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 23 Jan 2025 22:36:41 -0600 Subject: [PATCH 10/23] fix(docs): extra spaces Signed-off-by: Kyle Mistele --- docs/source/serving/openai_compatible_server.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 3f8edfb0e7105..8c54f43146b3f 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -491,8 +491,8 @@ You can find the documentation for these kind of models at [sbert.net](https://w The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the `score` task.
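Patch 09 above registers the same handler under `/v2/rerank`, so clients hard-wired to Cohere's v2 path work without changes. A quick smoke test of the new route (host, port, and model are assumptions, matching the `vllm serve BAAI/bge-reranker-base` setup used in the other examples):

```python
import requests

# Exercise the /v2/rerank alias added in patch 09; the payload and the
# response shape are identical to /rerank and /v1/rerank.
response = requests.post("http://localhost:8000/v2/rerank",
                         json={
                             "model": "BAAI/bge-reranker-base",
                             "query": "What is the capital of France?",
                             "documents": [
                                 "The capital of Brazil is Brasilia.",
                                 "The capital of France is Paris.",
                             ],
                             "top_n": 1,
                         })
response.raise_for_status()
# Expect one result: the most relevant document, carrying its original index.
print(response.json()["results"][0])
```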
Additionally, both `/rerank` and `/v1/rerank` endpoints -endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and -[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with +endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and +[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with popular open-source tools. Code example: From c44dee467738139ec57c6ecc9efb6317c893634b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 23 Jan 2025 22:37:57 -0600 Subject: [PATCH 11/23] fix(docs): cross-reference target for rerank API Signed-off-by: Kyle Mistele --- docs/source/serving/openai_compatible_server.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 8c54f43146b3f..24f3decd49280 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -479,7 +479,7 @@ The following extra parameters are supported: :end-before: end-score-extra-params ``` -(rerank-api) = +(rerank-api)= ### Re-rank API From cce2873e8229bbdf9fb9f73945cb5aca9e9cb256 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 23 Jan 2025 22:39:58 -0600 Subject: [PATCH 12/23] fix(tests): needed to break up model quotes Signed-off-by: Kyle Mistele --- tests/entrypoints/openai/test_rerank.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 0d6cde6c05ad9..d4f9f46676323 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -10,7 +10,7 @@ @pytest.fixture(scope="module") def server(): - args = ['--enforce-eager', '--max-model-len 100'] + args = ["--enforce-eager", "--max-model-len", "100"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server From a38060ff7ab8dbca8680666b06c7a0ae5d5f40dc Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 23 Jan 2025 23:05:11 -0600 Subject: [PATCH 13/23] doc(example): update jina example to reflect lack of SDK, add cohere example Signed-off-by: Kyle Mistele --- .../online_serving/cohere_rerank_client.py | 37 +++++++++++++++++++ .../online_serving/jinjaai_rerank_client.py | 4 +- 2 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 examples/online_serving/cohere_rerank_client.py diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/cohere_rerank_client.py new file mode 100644 index 0000000000000..97e8d40fc9d2c --- /dev/null +++ b/examples/online_serving/cohere_rerank_client.py @@ -0,0 +1,37 @@ +""" +Example of using the OpenAI entrypoint's rerank API which is compatible with +the Cohere SDK: https://github.com/cohere-ai/cohere-python + +run: vllm serve --model BAAI/bge-reranker-base +""" +import cohere + +# cohere v1 client +co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key") +rerank_v1_result = co.rerank( + model="BAAI/bge-reranker-base", + query="What is the capital of France?", + documents=[ + "The capital of France is Paris", + "Reranking is fun!", + "vLLM is an open-source framework for fast AI serving" + ] +) + +print(rerank_v1_result) + +# or the v2 +co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000") + +v2_rerank_result = co2.rerank( + model="BAAI/bge-reranker-base", + 
query="What is the capital of France?",
+    documents=[
+        "The capital of France is Paris",
+        "Reranking is fun!",
+        "vLLM is an open-source framework for fast AI serving"
+    ]
+)
+
+print(v2_rerank_result)
+

diff --git a/examples/online_serving/jinjaai_rerank_client.py b/examples/online_serving/jinjaai_rerank_client.py
index 84ff7461146eb..aefad48349c6d 100644
--- a/examples/online_serving/jinjaai_rerank_client.py
+++ b/examples/online_serving/jinjaai_rerank_client.py
@@ -1,6 +1,7 @@
 """
 Example of using the OpenAI entrypoint's rerank API which is compatible with
-Jina and Cohere
+Jina and Cohere https://jina.ai/reranker
+
 run: vllm serve --model BAAI/bge-reranker-base
 """
 import json
@@ -21,7 +22,6 @@
         "The capital of France is Paris.", "Horses and cows are both animals"
     ]
 }
-
 response = requests.post(url, headers=headers, json=data)
 
 # Check the response

From 901021fdbe1e65078f9f12e75e0652570585fbf5 Mon Sep 17 00:00:00 2001
From: Kyle Mistele
Date: Thu, 23 Jan 2025 23:24:02 -0600
Subject: [PATCH 14/23] fix: remove logger warnings and make the linter happy

Signed-off-by: Kyle Mistele

---
 examples/online_serving/cohere_rerank_client.py | 13 ++++---------
 vllm/entrypoints/openai/api_server.py           |  8 --------
 2 files changed, 4 insertions(+), 17 deletions(-)

diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/cohere_rerank_client.py
index 97e8d40fc9d2c..b8934395a7a50 100644
--- a/examples/online_serving/cohere_rerank_client.py
+++ b/examples/online_serving/cohere_rerank_client.py
@@ -12,11 +12,9 @@
     model="BAAI/bge-reranker-base",
     query="What is the capital of France?",
     documents=[
-        "The capital of France is Paris",
-        "Reranking is fun!",
+        "The capital of France is Paris", "Reranking is fun!",
         "vLLM is an open-source framework for fast AI serving"
-    ]
-)
+    ])
 
 print(rerank_v1_result)
 
@@ -27,11 +25,8 @@
     model="BAAI/bge-reranker-base",
     query="What is the capital of France?",
     documents=[
-        "The capital of France is Paris",
-        "Reranking is fun!",
+        "The capital of France is Paris", "Reranking is fun!",
         "vLLM is an open-source framework for fast AI serving"
-    ]
-)
+    ])
 
 print(v2_rerank_result)
-

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 41e14af240b38..9df9e405ef1f2 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -528,20 +528,12 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
 @router.post("/v1/rerank")
 @with_cancellation
 async def do_rerank_v1(request: RerankRequest, raw_request: Request):
-    logger.warning(
-        "To indicate that the rerank API is not part of the standard OpenAI"
-        " API, we have located it at `/rerank`. Please update your client "
-        "accordingly. (Note: Conforms to JinaAI rerank API)")
     return await do_rerank(request, raw_request)
 
 
 @router.post("/v2/rerank")
 @with_cancellation
 async def do_rerank_v2(request: RerankRequest, raw_request: Request):
-    logger.warning(
-        "To indicate that the rerank API is not part of the standard OpenAI"
-        " API, we have located it at `/rerank`. Please update your client "
-        "accordingly. (Note: Conforms to JinaAI rerank API)")
     return await do_rerank(request, raw_request)

From 48495753bd3fc92d2509dc3790c4820ed923e7f1 Mon Sep 17 00:00:00 2001
From: Kyle Mistele
Date: Fri, 24 Jan 2025 00:20:30 -0600
Subject: [PATCH 15/23] fix: file name

Signed-off-by: Kyle Mistele

---
 .../{jinjaai_rerank_client.py => jinaai_rerank_client.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename examples/online_serving/{jinjaai_rerank_client.py => jinaai_rerank_client.py} (100%)

diff --git a/examples/online_serving/jinjaai_rerank_client.py b/examples/online_serving/jinaai_rerank_client.py
similarity index 100%
rename from examples/online_serving/jinjaai_rerank_client.py
rename to examples/online_serving/jinaai_rerank_client.py

From 36e85a56623fad9af98edad8873ca95a8821232c Mon Sep 17 00:00:00 2001
From: Kyle Mistele
Date: Fri, 24 Jan 2025 00:21:31 -0600
Subject: [PATCH 16/23] fix(nit): ordering on assertions

Signed-off-by: Kyle Mistele

---
 tests/entrypoints/openai/test_rerank.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py
index d4f9f46676323..52ad6e2b8e776 100644
--- a/tests/entrypoints/openai/test_rerank.py
+++ b/tests/entrypoints/openai/test_rerank.py
@@ -36,8 +36,8 @@ def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
     assert rerank.id is not None
     assert rerank.results is not None
     assert len(rerank.results) == 2
-    assert rerank.results[1].relevance_score <= 0.01
     assert rerank.results[0].relevance_score >= 0.9
+    assert rerank.results[1].relevance_score <= 0.01
 
 
 @pytest.mark.asyncio
@@ -62,8 +62,8 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str):
     assert rerank.id is not None
     assert rerank.results is not None
     assert len(rerank.results) == 2
-    assert rerank.results[1].relevance_score <= 0.01
     assert rerank.results[0].relevance_score >= 0.9
+    assert rerank.results[1].relevance_score <= 0.01
 
 
 @pytest.mark.asyncio

From 4adb94b3ead5d7040a4fdaeb486e15470f5cb78a Mon Sep 17 00:00:00 2001
From: Kyle Mistele
Date: Fri, 24 Jan 2025 12:59:49 -0600
Subject: [PATCH 17/23] fix(tests): was using score instead of rerank

Signed-off-by: Kyle Mistele

---
 tests/entrypoints/openai/test_rerank.py | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py
index 0d6cde6c05ad9..8f6cbc35893c4 100644
--- a/tests/entrypoints/openai/test_rerank.py
+++ b/tests/entrypoints/openai/test_rerank.py
@@ -49,7 +49,7 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str):
         "The capital of France is Paris.", "Cross-encoder models are neat"
     ]
 
-    rerank_response = requests.post(server.url_for("score"),
+    rerank_response = requests.post(server.url_for("rerank"),
                                     json={
                                         "model": model_name,
                                         "query": query,
@@ -68,7 +68,7 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
+def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
 
     query = "What is the capital of France?" * 100
     documents = [
         "The capital of Brazil is Brasilia.", "The capital of France is Paris."
     ]
 
@@ -85,14 +85,3 @@ def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
     # Assert just a small fragment of the response
     assert "Please reduce the length of the input."
in \
        rerank_response.text
-
-    # Test truncation
-    rerank_response = requests.post(server.url_for("rerank"),
-                                    json={
-                                        "model": model_name,
-                                        "query": query,
-                                        "documents": documents
-                                    })
-    assert rerank_response.status_code == 400
-    assert "Please, select a smaller truncation size." in \
-        rerank_response.text

From dc92240d5be5a59ba88f7a535a68acd98658c5fa Mon Sep 17 00:00:00 2001
From: Kyle Mistele
Date: Fri, 24 Jan 2025 13:04:22 -0600
Subject: [PATCH 18/23] fix(api): use rerank as the default API for scoring

Signed-off-by: Kyle Mistele

---
 vllm/entrypoints/openai/api_server.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 7102976b625bc..699910a9a5d72 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -545,9 +545,6 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
         "default": (EmbeddingCompletionRequest, create_embedding),
     },
     "score": {
-        "default": (ScoreRequest, create_score),
-    },
-    "rerank": {
         "default": (RerankRequest, do_rerank)
     },
     "reward": {

From 29a03667d9cb4c66bd6184f217f6134860906900 Mon Sep 17 00:00:00 2001
From: Kyle Mistele
Date: Sat, 25 Jan 2025 14:10:21 -0600
Subject: [PATCH 19/23] doc: v2 rerank endpoint

Signed-off-by: Kyle Mistele

---
 docs/source/serving/openai_compatible_server.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index e2fab975731d8..8bc234545befd 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -50,7 +50,7 @@ In addition, we have the following custom APIs:
   - Applicable to all [pooling models](../models/pooling_models.md).
 - [Score API](#score-api) (`/score`)
   - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
-- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`)
+- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
   - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
   - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
   - Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response.
@@ -490,7 +490,7 @@ a scale of 0 to 1.
 
 You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
 
 The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the
-`score` task. Additionally, both `/rerank` and `/v1/rerank`
+`score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank`
 endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and
 [Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
 popular open-source tools.
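A minimal sketch of exercising the `/v2/rerank` route documented above with plain `requests`. Assumptions: a server started locally via `vllm serve BAAI/bge-reranker-base`, listening on the default port 8000; the payload mirrors the Jina-style requests used in the tests and example clients above.

import requests

# Jina-style rerank payload; "top_n" caps how many results come back.
response = requests.post(
    "http://localhost:8000/v2/rerank",
    headers={"accept": "application/json", "Content-Type": "application/json"},
    json={
        "model": "BAAI/bge-reranker-base",
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
        "top_n": 1,
    },
)
response.raise_for_status()

# Each result pairs the index of the original document with its
# relevance_score, so scores can be mapped back to the input documents.
for result in response.json()["results"]:
    print(result["index"], result["relevance_score"])

The same payload works unchanged against `/rerank` and `/v1/rerank`, since all three routes dispatch to the same do_rerank handler.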
From 844d39a212a6fc1d283e5b9cdac35381528bffbd Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 25 Jan 2025 14:11:34 -0600 Subject: [PATCH 20/23] fix: remove duplicate file and fix vllm start command in examples Signed-off-by: Kyle Mistele --- .../online_serving/cohere_rerank_client.py | 2 +- .../online_serving/jinaai_rerank_client.py | 2 +- .../online_serving/jinjaai_rerank_client.py | 33 ------------------- 3 files changed, 2 insertions(+), 35 deletions(-) delete mode 100644 examples/online_serving/jinjaai_rerank_client.py diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/cohere_rerank_client.py index b8934395a7a50..a07affe3351ce 100644 --- a/examples/online_serving/cohere_rerank_client.py +++ b/examples/online_serving/cohere_rerank_client.py @@ -2,7 +2,7 @@ Example of using the OpenAI entrypoint's rerank API which is compatible with the Cohere SDK: https://github.com/cohere-ai/cohere-python -run: vllm serve --model BAAI/bge-reranker-base +run: vllm serve BAAI/bge-reranker-base """ import cohere diff --git a/examples/online_serving/jinaai_rerank_client.py b/examples/online_serving/jinaai_rerank_client.py index aefad48349c6d..bf4de76ddf362 100644 --- a/examples/online_serving/jinaai_rerank_client.py +++ b/examples/online_serving/jinaai_rerank_client.py @@ -2,7 +2,7 @@ Example of using the OpenAI entrypoint's rerank API which is compatible with Jina and Cohere https://jina.ai/reranker -run: vllm serve --model BAAI/bge-reranker-base +run: vllm serve BAAI/bge-reranker-base """ import json diff --git a/examples/online_serving/jinjaai_rerank_client.py b/examples/online_serving/jinjaai_rerank_client.py deleted file mode 100644 index 84ff7461146eb..0000000000000 --- a/examples/online_serving/jinjaai_rerank_client.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Example of using the OpenAI entrypoint's rerank API which is compatible with -Jina and Cohere -run: vllm serve --model BAAI/bge-reranker-base -""" -import json - -import requests - -url = "http://127.0.0.1:8000/rerank" - -headers = {"accept": "application/json", "Content-Type": "application/json"} - -data = { - "model": - "BAAI/bge-reranker-base", - "query": - "What is the capital of France?", - "documents": [ - "The capital of Brazil is Brasilia.", - "The capital of France is Paris.", "Horses and cows are both animals" - ] -} - -response = requests.post(url, headers=headers, json=data) - -# Check the response -if response.status_code == 200: - print("Request successful!") - print(json.dumps(response.json(), indent=2)) -else: - print(f"Request failed with status code: {response.status_code}") - print(response.text) From af83c25304be727ff6feb1c2b5e7b0b6934749c8 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 25 Jan 2025 14:13:45 -0600 Subject: [PATCH 21/23] fix: only load serving rerank if model supports score Signed-off-by: Kyle Mistele --- vllm/entrypoints/openai/api_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c25d5412b1410..45cf06566faaa 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -806,7 +806,8 @@ async def init_app_state( engine_client, model_config, state.openai_serving_models, - request_logger=request_logger) + request_logger=request_logger + ) if model_config.task == "score" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, From 17441f5e9116bbac28fe56d74e2623e7c046effd Mon 
Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 26 Jan 2025 12:45:39 -0600 Subject: [PATCH 22/23] merge Signed-off-by: Kyle Mistele --- tests/entrypoints/openai/test_score.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 06e0f93dbe269..0d19615bc0d99 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -10,12 +10,7 @@ @pytest.fixture(scope="module") def server(): - args = [ - "--enforce-eager", - # Will be used on tests to compare prompt input length - "--max-model-len", - "100" - ] + args = ["--enforce-eager", "--max-model-len", "100"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server From 974c0be65f627cfbb25873921d2823786ed2b421 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 26 Jan 2025 14:53:15 -0600 Subject: [PATCH 23/23] fix(tests): use correct API for rerank tests Signed-off-by: Kyle Mistele --- tests/entrypoints/openai/test_rerank.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 982c2c2328b66..cfd8f33133960 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -49,7 +49,7 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): "The capital of France is Paris.", "Cross-encoder models are neat" ] - rerank_response = requests.post(server.url_for("score"), + rerank_response = requests.post(server.url_for("rerank"), json={ "model": model_name, "query": query,