@@ -53,7 +53,13 @@ def upvote_agent_clicked(question, comment, env_vars):
5353def clean_st_history (selected_profile ):
5454 st .session_state .messages [selected_profile ] = []
5555
56- def get_user_history (selected_profile ):
56+
57+ def get_user_history (selected_profile : str ):
58+ """
59+ get user history for selected profile
60+ :param selected_profile:
61+ :return: history for selected profile list type
62+ """
5763 history_list = st .session_state .messages [selected_profile ]
5864 history_query = []
5965 for messages in history_list :
@@ -62,6 +68,7 @@ def get_user_history(selected_profile):
6268 history_query .append (messages ["content" ])
6369 return history_query
6470
71+
6572def do_visualize_results (nlq_chain , sql_result ):
6673 sql_query_result = sql_result
6774 if sql_query_result is not None :
@@ -98,7 +105,7 @@ def do_visualize_results(nlq_chain, sql_result):
98105 st .markdown ('No visualization generated.' )
99106
100107
101- def recurrent_display (messages , i , current_nlq_chain ):
108+ def recurrent_display (messages , i ):
102109 # hacking way of displaying messages, since the chat_message does not support multiple messages outside of "with" statement
103110 current_role = messages [i ]["role" ]
104111 message = messages [i ]
@@ -115,12 +122,13 @@ def recurrent_display(messages, i, current_nlq_chain):
115122 st .error (message ["content" ])
116123 elif message ["type" ] == "sql" :
117124 with st .expander ("The Generate SQL" ):
118- st .code (message ["content" ], language = "sql" )
125+ st .code (message ["content" ], language = "sql" )
119126 return i
120127
121128
122- def normal_text_search_streamlit (search_box , model_type , database_profile , entity_slot , env_vars , selected_profile , use_rag ,
123- model_provider = None ):
129+ def normal_text_search_streamlit (search_box , model_type , database_profile , entity_slot , env_vars , selected_profile ,
130+ use_rag ,
131+ model_provider = None ):
124132 entity_slot_retrieve = []
125133 retrieve_result = []
126134 response = ""
@@ -154,7 +162,7 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
154162 with st .status ("Performing QA retrieval..." ) as status_text :
155163 if use_rag :
156164 retrieve_result = get_retrieve_opensearch (env_vars , search_box , "query" ,
157- selected_profile , 3 , 0.5 )
165+ selected_profile , 3 , 0.5 )
158166 examples = []
159167 for example in retrieve_result :
160168 examples .append ({'Score' : example ['_score' ],
@@ -167,14 +175,14 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
167175
168176 with st .status ("Generating SQL... " ) as status_text :
169177 response = text_to_sql (database_profile ['tables_info' ],
170- database_profile ['hints' ],
171- database_profile ['prompt_map' ],
172- search_box ,
173- model_id = model_type ,
174- sql_examples = retrieve_result ,
175- ner_example = entity_slot_retrieve ,
176- dialect = database_profile ['db_type' ],
177- model_provider = model_provider )
178+ database_profile ['hints' ],
179+ database_profile ['prompt_map' ],
180+ search_box ,
181+ model_id = model_type ,
182+ sql_examples = retrieve_result ,
183+ ner_example = entity_slot_retrieve ,
184+ dialect = database_profile ['db_type' ],
185+ model_provider = model_provider )
178186
179187 sql = get_generated_sql (response )
180188
@@ -193,9 +201,8 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
193201 pass
194202
195203 status_text .update (
196- label = f"Generating SQL Done" ,
197- state = "complete" , expanded = True )
198-
204+ label = f"Generating SQL Done" ,
205+ state = "complete" , expanded = True )
199206
200207 search_result = SearchTextSqlResult (search_query = search_box , entity_slot_retrieve = entity_slot_retrieve ,
201208 retrieve_result = retrieve_result , response = response , sql = "" )
@@ -207,10 +214,10 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
207214 logger .error (e )
208215 return search_result
209216
217+
210218def main ():
211219 load_dotenv ()
212220
213- # load config.json as dictionary
214221 with open (os .path .join (os .getcwd (), 'config_files' , '1_config.json' )) as f :
215222 env_vars = json .load (f )
216223 opensearch_config = env_vars ['data_sources' ]['shopping_guide' ]['opensearch' ]
@@ -243,12 +250,6 @@ def main():
243250 if 'selected_sample' not in st .session_state :
244251 st .session_state ['selected_sample' ] = ''
245252
246- if 'dataframe' not in st .session_state :
247- st .session_state ['dataframe' ] = pd .DataFrame ({
248- 'column1' : ['A' , 'B' , 'C' ],
249- 'column2' : [1 , 2 , 3 ]
250- })
251-
252253 if 'current_profile' not in st .session_state :
253254 st .session_state ['current_profile' ] = ''
254255
@@ -336,7 +337,7 @@ def main():
336337 # if i - 1 < new_index:
337338 # continue
338339 with st .chat_message (st .session_state .messages [selected_profile ][i ]["role" ]):
339- new_index = recurrent_display (st .session_state .messages [selected_profile ], i , current_nlq_chain )
340+ new_index = recurrent_display (st .session_state .messages [selected_profile ], i )
340341
341342 text_placeholder = "Type your query here..."
342343
@@ -360,9 +361,6 @@ def main():
360361 {"role" : "user" , "content" : search_box , "type" : "text" })
361362 st .markdown (current_nlq_chain .get_question ())
362363 with st .chat_message ("assistant" ):
363- # retrieve_result = []
364- # entity_slot_retrieve = []
365- # deep_dive_sql_result = []
366364 filter_deep_dive_sql_result = []
367365 entity_slot = []
368366 normal_search_result = None
@@ -384,8 +382,7 @@ def main():
384382 prompt_map [key ] = prompt_map_dict [key ]
385383 ProfileManagement .update_table_prompt_map (selected_profile , prompt_map )
386384
387-
388- # 多轮对话,query改写
385+ # Multiple rounds of dialogue, query rewriting
389386 user_query_history = get_user_history (selected_profile )
390387 if len (user_query_history ) > 0 :
391388 with st .status ("Query Context Understanding" ) as status_text :
@@ -395,7 +392,7 @@ def main():
395392 logger .info ("The Origin query is {query} query rewrite is {new_query}" .format (query = search_box ,
396393 new_query = new_search_box ))
397394 search_box = new_search_box
398- st .write (search_box )
395+ st .write (search_box )
399396 status_text .update (label = f"Query Context Rewrite Completed" , state = "complete" , expanded = False )
400397 intent_response = {
401398 "intent" : "normal_search" ,
@@ -437,10 +434,10 @@ def main():
437434 elif search_intent_flag :
438435 # 执行普通的查询,并可视化结果
439436 normal_search_result = normal_text_search_streamlit (search_box , model_type ,
440- database_profile ,
441- entity_slot , env_vars ,
442- selected_profile ,
443- explain_gen_process_flag , use_rag_flag )
437+ database_profile ,
438+ entity_slot , env_vars ,
439+ selected_profile ,
440+ explain_gen_process_flag , use_rag_flag )
444441 elif knowledge_search_flag :
445442 with st .spinner ('Performing knowledge search...' ):
446443 response = knowledge_search (search_box = search_box , model_id = model_type ,
@@ -459,8 +456,8 @@ def main():
459456 agent_examples = []
460457 for example in agent_cot_retrieve :
461458 agent_examples .append ({'Score' : example ['_score' ],
462- 'Question' : example ['_source' ]['query' ],
463- 'Answer' : example ['_source' ]['comment' ].strip ()})
459+ 'Question' : example ['_source' ]['query' ],
460+ 'Answer' : example ['_source' ]['comment' ].strip ()})
464461 st .write (agent_examples )
465462 with st .expander (f'Agent Task : { len (agent_cot_task_result )} ' ):
466463 st .write (agent_cot_task_result )
@@ -505,7 +502,8 @@ def main():
505502 with st .expander ("The SQL Error Info" ):
506503 st .markdown (search_intent_result ["error_info" ])
507504 else :
508- if search_intent_result ["data" ] is not None and len (search_intent_result ["data" ]) > 0 and data_with_analyse :
505+ if search_intent_result ["data" ] is not None and len (
506+ search_intent_result ["data" ]) > 0 and data_with_analyse :
509507 with st .spinner ('Generating data summarize...' ):
510508 search_intent_analyse_result = data_analyse_tool (model_type , prompt_map , search_box ,
511509 search_intent_result ["data" ].to_json (
0 commit comments