[Frontend] Rerank API (Jina- and Cohere-compatible API) #12376
Changes from 9 commits
cohere_rerank_client.py (new file)

@@ -0,0 +1,37 @@
"""
Example of using the OpenAI entrypoint's rerank API, which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python

run: vllm serve BAAI/bge-reranker-base
"""
import cohere

# cohere v1 client
co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
rerank_v1_result = co.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris",
        "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ]
)

print(rerank_v1_result)

# or the v2 client
co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")

v2_rerank_result = co2.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris",
        "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ]
)

print(v2_rerank_result)
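To tie the scores back to the inputs, each result references the input list by position. A minimal sketch of consuming the v1 result, continuing the script above and assuming the cohere-python response exposes `results` items with `index` and `relevance_score` attributes (true for recent client versions):

docs = [
    "The capital of France is Paris",
    "Reranking is fun!",
    "vLLM is an open-source framework for fast AI serving"
]
# results come back ordered by relevance; `index` points into the input list
for item in rerank_v1_result.results:
    print(f"{item.relevance_score:.4f}  {docs[item.index]}")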
jinaai_rerank_client.py (new file)

@@ -0,0 +1,33 @@
"""
Example of using the OpenAI entrypoint's rerank API, which is compatible with
Jina and Cohere: https://jina.ai/reranker

run: vllm serve BAAI/bge-reranker-base
"""
import json

import requests

url = "http://127.0.0.1:8000/rerank"

headers = {"accept": "application/json", "Content-Type": "application/json"}

data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Horses and cows are both animals"
    ]
}
response = requests.post(url, headers=headers, json=data)

# Check the response
if response.status_code == 200:
    print("Request successful!")
    print(json.dumps(response.json(), indent=2))
else:
    print(f"Request failed with status code: {response.status_code}")
    print(response.text)
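The endpoint also accepts a `top_n` parameter (exercised by the tests below) to limit how many results come back. A short sketch, assuming the same local server is still running and the Jina-style response shape with `index` and `relevance_score` fields:

import requests

response = requests.post(
    "http://127.0.0.1:8000/rerank",
    json={
        "model": "BAAI/bge-reranker-base",
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
            "Horses and cows are both animals"
        ],
        "top_n": 2  # keep only the two highest-scoring documents
    })
response.raise_for_status()
for result in response.json()["results"]:
    print(result["index"], result["relevance_score"])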
test_rerank.py (new file)

@@ -0,0 +1,98 @@
import pytest
import requests

from vllm.entrypoints.openai.protocol import RerankResponse

from ...utils import RemoteOpenAIServer

MODEL_NAME = "BAAI/bge-reranker-base"


@pytest.fixture(scope="module")
def server():
    args = ["--enforce-eager", "--max-model-len", "100"]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                    })
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[1].relevance_score <= 0.01
    assert rerank.results[0].relevance_score >= 0.9


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str):
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Cross-encoder models are neat"
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                        "top_n": 2
                                    })
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[1].relevance_score <= 0.01
    assert rerank.results[0].relevance_score >= 0.9


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):

    query = "What is the capital of France?" * 100
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents
                                    })
    assert rerank_response.status_code == 400
    # Assert just a small fragment of the response
    assert "Please reduce the length of the input." in \
        rerank_response.text

    # Test truncation: a truncation size above --max-model-len is rejected
    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                        "truncate_prompt_tokens": 101
                                    })
    assert rerank_response.status_code == 400
    assert "Please, select a smaller truncation size." in \
        rerank_response.text
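For reference, a response body that `RerankResponse.model_validate` would accept looks roughly like the sketch below. The field names follow the Jina rerank API; the concrete id, usage, and score values are illustrative only:

example_response = {
    "id": "rerank-abc123",  # illustrative request id
    "model": "BAAI/bge-reranker-base",
    "usage": {"total_tokens": 42},  # illustrative token count
    "results": [
        {
            "index": 1,  # position of the document in the request
            "document": {"text": "The capital of France is Paris."},
            "relevance_score": 0.99
        },
        {
            "index": 0,
            "document": {"text": "The capital of Brazil is Brasilia."},
            "relevance_score": 0.001
        }
    ]
}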
vllm/entrypoints/openai/api_server.py

@@ -56,6 +56,7 @@
                                              PoolingChatRequest,
                                              PoolingCompletionRequest,
                                              PoolingRequest, PoolingResponse,
+                                             RerankRequest, RerankResponse,
                                              ScoreRequest, ScoreResponse,
                                              TokenizeRequest,
                                              TokenizeResponse,

@@ -68,6 +69,7 @@
 from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                     OpenAIServingModels)
 from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
+from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
 from vllm.entrypoints.openai.serving_score import OpenAIServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)

@@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
     return request.app.state.openai_serving_scores


+def rerank(request: Request) -> Optional[JinaAIServingRerank]:
+    return request.app.state.jinaai_serving_reranking
+
+
 def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization


@@ -502,6 +508,43 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     return await create_score(request, raw_request)


+@router.post("/rerank")
+@with_cancellation
+async def do_rerank(request: RerankRequest, raw_request: Request):
+    handler = rerank(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support the Rerank (Score) API")
+    generator = await handler.do_rerank(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, RerankResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    assert_never(generator)
+
+
+@router.post("/v1/rerank")
+@with_cancellation
+async def do_rerank_v1(request: RerankRequest, raw_request: Request):
+    logger.warning(
+        "To indicate that the rerank API is not part of the standard OpenAI"
+        " API, we have located it at `/rerank`. Please update your client"
+        " accordingly. (Note: Conforms to JinaAI rerank API)")
Comment on lines +531 to +534

Member: I think we should remove these warnings, as the Cohere Python client will access this URL by default. Unless there's a way to change the URL in the client?

Contributor (Author): There's a way to change the base URL, but that's just the server or hostname, unlike OpenAI, which expects you to include the `/v1` in the base URL. I will remove the logger warnings.

Member: Remember to remove them or switch to `warning_once`.
+    return await do_rerank(request, raw_request)
+
+
+@router.post("/v2/rerank")
+@with_cancellation
+async def do_rerank_v2(request: RerankRequest, raw_request: Request):
+    logger.warning(
+        "To indicate that the rerank API is not part of the standard OpenAI"
+        " API, we have located it at `/rerank`. Please update your client"
+        " accordingly. (Note: Conforms to JinaAI rerank API)")
+    return await do_rerank(request, raw_request)


 TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
     "generate": {
         "messages": (ChatCompletionRequest, create_chat_completion),

@@ -514,6 +557,9 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     "score": {
         "default": (ScoreRequest, create_score),
     },
+    "rerank": {
+        "default": (RerankRequest, do_rerank)
+    },
     "reward": {
         "messages": (PoolingChatRequest, create_pooling),
         "default": (PoolingCompletionRequest, create_pooling),

@@ -759,6 +805,11 @@ async def init_app_state(
         state.openai_serving_models,
         request_logger=request_logger
     ) if model_config.task == "score" else None
+    state.jinaai_serving_reranking = JinaAIServingRerank(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger)
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
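Because the Cohere v1 and v2 clients append `/v1/rerank` and `/v2/rerank` to the configured base URL, the two aliases above are what make the Cohere example work against a vLLM server. A quick smoke check, assuming a local server on the default port with the same model:

import requests

payload = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": ["The capital of France is Paris"]
}
# all three routes are served by the same do_rerank handler
for path in ("/rerank", "/v1/rerank", "/v2/rerank"):
    r = requests.post(f"http://127.0.0.1:8000{path}", json=payload)
    print(path, r.status_code)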