Skip to content

Commit dfdd08c

Browse files
Liangyx2, pre-commit-ci[bot], XuehaoSun, chensuyue
authored
Use parameter for reranker (#177)
* Use parameter for reranker Signed-off-by: Liangyx2 <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Liangyx2 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sun, Xuehao <[email protected]> Co-authored-by: chen, suyue <[email protected]>
1 parent 9e91843 commit dfdd08c

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

comps/cores/proto/docarray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class EmbedDoc1024(BaseDoc):
5858
class SearchedDoc(BaseDoc):
5959
retrieved_docs: DocList[TextDoc]
6060
initial_query: str
61+
top_n: int = 1
6162

6263
class Config:
6364
json_encoders = {np.ndarray: lambda x: x.tolist()}

comps/reranks/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,12 @@ curl http://localhost:8000/v1/reranking \
100100
-d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}]}' \
101101
-H 'Content-Type: application/json'
102102
```
103+
104+
You can add the parameter `top_n` to specify the number of documents returned by the reranker model; the default value is 1.
105+
106+
```bash
107+
curl http://localhost:8000/v1/reranking \
108+
-X POST \
109+
-d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}], "top_n":2}' \
110+
-H 'Content-Type: application/json'
111+
```

comps/reranks/langchain/reranking_tei_xeon.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import heapq
45
import json
56
import os
67
import re
@@ -40,9 +41,11 @@ def reranking(input: SearchedDoc) -> LLMParamsDoc:
4041
headers = {"Content-Type": "application/json"}
4142
response = requests.post(url, data=json.dumps(data), headers=headers)
4243
response_data = response.json()
43-
best_response = max(response_data, key=lambda response: response["score"])
44-
doc = input.retrieved_docs[best_response["index"]]
45-
if doc.text and len(re.findall("[\u4E00-\u9FFF]", doc.text)) / len(doc.text) >= 0.3:
44+
best_response_list = heapq.nlargest(input.top_n, response_data, key=lambda x: x["score"])
45+
context_str = ""
46+
for best_response in best_response_list:
47+
context_str = context_str + " " + input.retrieved_docs[best_response["index"]].text
48+
if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
4649
# chinese context
4750
template = "仅基于以下背景回答问题:\n{context}\n问题: {question}"
4851
else:
@@ -51,7 +54,7 @@ def reranking(input: SearchedDoc) -> LLMParamsDoc:
5154
Question: {question}
5255
"""
5356
prompt = ChatPromptTemplate.from_template(template)
54-
final_prompt = prompt.format(context=doc.text, question=input.initial_query)
57+
final_prompt = prompt.format(context=context_str, question=input.initial_query)
5558
statistics_dict["opea_service@reranking_tgi_gaudi"].append_latency(time.time() - start, None)
5659
return LLMParamsDoc(query=final_prompt.strip())
5760

0 commit comments

Comments
 (0)