 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import heapq
 import json
 import os
 import re
@@ -40,9 +41,11 @@ def reranking(input: SearchedDoc) -> LLMParamsDoc:
     headers = {"Content-Type": "application/json"}
     response = requests.post(url, data=json.dumps(data), headers=headers)
     response_data = response.json()
-    best_response = max(response_data, key=lambda response: response["score"])
-    doc = input.retrieved_docs[best_response["index"]]
-    if doc.text and len(re.findall("[\u4E00-\u9FFF]", doc.text)) / len(doc.text) >= 0.3:
+    best_response_list = heapq.nlargest(input.top_n, response_data, key=lambda x: x["score"])
+    context_str = ""
+    for best_response in best_response_list:
+        context_str = context_str + " " + input.retrieved_docs[best_response["index"]].text
+    if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
         # chinese context
         template = "仅基于以下背景回答问题:\n{context}\n问题: {question}"
     else:
@@ -51,7 +54,7 @@ def reranking(input: SearchedDoc) -> LLMParamsDoc:
     Question: {question}
     """
     prompt = ChatPromptTemplate.from_template(template)
-    final_prompt = prompt.format(context=doc.text, question=input.initial_query)
+    final_prompt = prompt.format(context=context_str, question=input.initial_query)
     statistics_dict["opea_service@reranking_tgi_gaudi"].append_latency(time.time() - start, None)
     return LLMParamsDoc(query=final_prompt.strip())
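The key change in the first hunk swaps the single `max(...)` pick for `heapq.nlargest`, so the top `input.top_n` reranker hits are concatenated into one context string instead of keeping only the best document. (Note the loop leaves a leading space on `context_str`; the sketch below uses `" ".join(...)`, which avoids that.) Here is a minimal standalone sketch of the selection step; the `[{"index": ..., "score": ...}]` response shape mirrors the reranker output the diff consumes, while the function name and sample data are purely illustrative.

```python
import heapq


def select_top_n_context(scores: list[dict], docs: list[str], top_n: int) -> str:
    """Join the texts of the top_n highest-scoring documents, best first.

    `scores` is assumed to look like the reranker response in the diff:
    a list of {"index": int, "score": float} entries pointing into `docs`.
    """
    best = heapq.nlargest(top_n, scores, key=lambda x: x["score"])
    return " ".join(docs[item["index"]] for item in best)


scores = [{"index": 0, "score": 0.21}, {"index": 1, "score": 0.93}, {"index": 2, "score": 0.77}]
docs = ["low-relevance text", "best match", "runner-up"]
print(select_top_n_context(scores, docs, top_n=2))  # -> "best match runner-up"
```

`heapq.nlargest(n, iterable, key=...)` returns results in descending key order and costs O(len(iterable) · log n), which is cheaper than sorting the whole response when `top_n` is small.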
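The `if`/`else` then picks a prompt template by language: when at least 30% of the characters in `context_str` fall in the CJK Unified Ideographs range (U+4E00–U+9FFF), the Chinese template is used; in translation it reads "Answer the question based only on the following background: {context} Question: {question}". Below is a hedged sketch of that selection, using plain `str.format` instead of LangChain's `ChatPromptTemplate` to stay dependency-free; the English template is only partially visible in the hunk, so its first line here is an assumption mirroring the Chinese one, and the 0.3 threshold is taken from the diff.

```python
import re

# CJK Unified Ideographs block, the same range the diff's heuristic matches.
CJK = re.compile("[\u4E00-\u9FFF]")

CHINESE_TEMPLATE = "仅基于以下背景回答问题:\n{context}\n问题: {question}"
# Assumed wording; only "Question: {question}" is visible in the hunk.
ENGLISH_TEMPLATE = """Answer the question based only on the following context:
{context}
Question: {question}
"""


def build_prompt(context: str, question: str, threshold: float = 0.3) -> str:
    """Use the Chinese template when >= `threshold` of the characters are CJK."""
    is_chinese = bool(context) and len(CJK.findall(context)) / len(context) >= threshold
    template = CHINESE_TEMPLATE if is_chinese else ENGLISH_TEMPLATE
    return template.format(context=context, question=question).strip()


print(build_prompt("量子计算机使用量子比特进行运算。", "什么是量子比特?"))
```

One caveat of the ratio test: it counts every character, including whitespace and punctuation, so a long mixed-language context can fall under the threshold even when the question itself is Chinese.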