77from nlq .business .connection import ConnectionManagement
88from nlq .business .nlq_chain import NLQChain
99from nlq .business .profile import ProfileManagement
10+ from nlq .business .vector_store import VectorStore
1011from utils .apis import get_sql_result_tool
1112from utils .database import get_db_url_dialect
1213from utils .llm import text_to_sql , get_query_intent , create_vector_embedding_with_sagemaker , \
@@ -182,7 +183,7 @@ def ask(question: Question) -> Answer:
182183 sql_gen_process = "" ,
183184 data_analyse = "" )
184185
185- agent_search_response = AgentSearchResult (agent_summary = "" , agent_sql_search_result = [])
186+ agent_search_response = AgentSearchResult (agent_summary = "" , agent_sql_search_result = [], sub_search_task = [] )
186187
187188 knowledge_search_result = KnowledgeSearchResult (knowledge_response = "" )
188189
@@ -221,7 +222,8 @@ def ask(question: Question) -> Answer:
221222
222223 if reject_intent_flag :
223224 answer = Answer (query = search_box , query_intent = "reject_search" , knowledge_search_result = knowledge_search_result ,
224- sql_search_result = sql_search_result , agent_search_result = agent_search_response , suggested_question = [])
225+ sql_search_result = sql_search_result , agent_search_result = agent_search_response ,
226+ suggested_question = [])
225227 return answer
226228 elif search_intent_flag :
227229 normal_search_result = normal_text_search (search_box , model_type ,
@@ -232,8 +234,10 @@ def ask(question: Question) -> Answer:
232234 response = knowledge_search (search_box = search_box , model_id = model_type )
233235
234236 knowledge_search_result .knowledge_response = response
235- answer = Answer (query = search_box , query_intent = "knowledge_search" , knowledge_search_result = knowledge_search_result ,
236- sql_search_result = sql_search_result , agent_search_result = agent_search_response , suggested_question = [])
237+ answer = Answer (query = search_box , query_intent = "knowledge_search" ,
238+ knowledge_search_result = knowledge_search_result ,
239+ sql_search_result = sql_search_result , agent_search_result = agent_search_response ,
240+ suggested_question = [])
237241 return answer
238242
239243 else :
@@ -269,38 +273,60 @@ def ask(question: Question) -> Answer:
269273 search_intent_result ["data" ].to_json (
270274 orient = 'records' , force_ascii = False ), "query" )
271275 sql_search_result .data_analyse = search_intent_analyse_result
272- sql_search_result .sql_data = [list (search_intent_result ["data" ].columns )] + search_intent_result ["data" ].values .tolist ()
276+ sql_search_result .sql_data = [list (search_intent_result ["data" ].columns )] + search_intent_result [
277+ "data" ].values .tolist ()
273278
274279 answer = Answer (query = search_box , query_intent = "normal_search" , knowledge_search_result = knowledge_search_result ,
275- sql_search_result = sql_search_result , agent_search_result = agent_search_response , suggested_question = [])
280+ sql_search_result = sql_search_result , agent_search_result = agent_search_response ,
281+ suggested_question = [])
276282 return answer
277283 else :
284+ sub_search_task = []
278285 for i in range (len (agent_search_result )):
279286 each_task_res = get_sql_result_tool (database_profile , agent_search_result [i ]["sql" ])
280287 if each_task_res ["status_code" ] == 200 and len (each_task_res ["data" ]) > 0 :
281288 agent_search_result [i ]["data_result" ] = each_task_res ["data" ].to_json (
282289 orient = 'records' )
283290 filter_deep_dive_sql_result .append (agent_search_result [i ])
284291 each_task_sql_res = [list (each_task_res ["data" ].columns )] + each_task_res ["data" ].values .tolist ()
285- each_task_sql_search_result = SQLSearchResult (query = agent_search_result [i ]["query" ], sql_data = each_task_sql_res ,
292+ each_task_sql_search_result = SQLSearchResult (query = agent_search_result [i ]["query" ],
293+ sql_data = each_task_sql_res ,
286294 sql = each_task_res ["sql" ], data_show_type = "table" ,
287295 sql_gen_process = "" ,
288296 data_analyse = "" )
289297 agent_sql_search_result .append (each_task_sql_search_result )
290-
298+ sub_search_task . append ( agent_search_result [ i ][ "query" ])
291299 agent_data_analyse_result = data_analyse_tool (model_type , search_box ,
292300 json .dumps (filter_deep_dive_sql_result , ensure_ascii = False ),
293301 "agent" )
294302 logger .info ("agent_data_analyse_result" )
295303 logger .info (agent_data_analyse_result )
296304 agent_search_response .agent_summary = agent_data_analyse_result
297305 agent_search_response .agent_sql_search_result = agent_sql_search_result
306+ agent_search_response .sub_search_task = sub_search_task
298307
299308 answer = Answer (query = search_box , query_intent = "agent_search" , knowledge_search_result = knowledge_search_result ,
300- sql_search_result = sql_search_result , agent_search_result = agent_search_response , suggested_question = [])
309+ sql_search_result = sql_search_result , agent_search_result = agent_search_response ,
310+ suggested_question = [])
301311 return answer
302312
303313
314+ def user_feedback_upvote (data_profiles : str , query : str , query_intent : str , query_answer_list ):
315+ try :
316+ if query_intent == "normal_search" :
317+ if len (query_answer_list ) > 0 :
318+ VectorStore .add_sample (data_profiles , query_answer_list [0 ].query , query_answer_list [0 ].sql )
319+ elif query_intent == "agent_search" :
320+ query_list = []
321+ for each in query_answer_list :
322+ query_list .append (each .query )
323+ VectorStore .add_sample (data_profiles , each .query , each .sql )
324+ VectorStore .add_agent_cot_sample (data_profiles , query , "\n " .join (query_list ))
325+ return True
326+ except Exception as e :
327+ return False
328+
329+
304330def get_nlq_chain (question : Question ) -> NLQChain :
305331 logger .debug (question )
306332 verify_parameters (question )
0 commit comments