44import pandas as pd
55
66from api .schemas import Answer , KnowledgeSearchResult , SQLSearchResult , AgentSearchResult , AskReplayResult , \
7- AskEntitySelect , ChartEntity
7+ AskEntitySelect , ChartEntity , TaskSQLSearchResult
88from nlq .business .datasource .factory import DataSourceFactory
99from nlq .business .log_store import LogManagement
1010from nlq .core .chat_context import ProcessingContext
@@ -457,22 +457,28 @@ def handle_analyze_data(self):
457457
458458 @log_execution
459459 def handle_agent_task (self ):
460-
461- self .agent_cot_retrieve = get_retrieve_opensearch (self .context .opensearch_info , self .context .query_rewrite ,
462- "agent" , self .context .selected_profile , 2 , 0.5 )
463-
464- agent_cot_task_result = get_agent_cot_task (self .context .model_type , self .context .database_profile ["prompt_map" ],
465- self .context .query_rewrite ,
466- self .context .database_profile ['tables_info' ],
467- self .agent_cot_retrieve )
468- self .agent_task_split = agent_cot_task_result
469- self .transition (QueryState .AGENT_SEARCH )
460+ # Analyze the task
461+ try :
462+ self .agent_cot_retrieve = get_retrieve_opensearch (self .context .opensearch_info , self .context .query_rewrite ,
463+ "agent" , self .context .selected_profile , 2 , 0.5 )
464+
465+ agent_cot_task_result = get_agent_cot_task (self .context .model_type , self .context .database_profile ["prompt_map" ],
466+ self .context .query_rewrite ,
467+ self .context .database_profile ['tables_info' ],
468+ self .agent_cot_retrieve )
469+ self .agent_task_split = agent_cot_task_result
470+ self .transition (QueryState .AGENT_SEARCH )
471+ except Exception as e :
472+ self .answer .error_log [QueryState .AGENT_TASK .name ] = str (e )
473+ logger .error (f"The context is { self .context .search_box } , handle_agent_task encountered an error: { e } " )
474+ self .transition (QueryState .ERROR )
470475
471476 @log_execution
472477 def handle_agent_analyze_data (self ):
473478 # Analyze the data
474479 try :
475480 filter_deep_dive_sql_result = []
481+ agent_sql_search_result = []
476482 for i in range (len (self .agent_search_result )):
477483 each_task_res = get_sql_result_tool (
478484 self .context .database_profile ,
@@ -489,6 +495,9 @@ def handle_agent_analyze_data(self):
489495 data_show_type = "table" ,
490496 sql_gen_process = each_task_sql_response ,
491497 data_analyse = "" , sql_data_chart = [])
498+ each_task_sql_search_result = TaskSQLSearchResult (sub_task_query = self .agent_search_result [i ]["query" ],
499+ sql_search_result = sub_task_sql_result )
500+ agent_sql_search_result .append (each_task_sql_search_result )
492501
493502 agent_data_analyse_result = data_analyse_tool (self .context .model_type ,
494503 self .context .database_profile ["prompt_map" ],
@@ -499,7 +508,7 @@ def handle_agent_analyze_data(self):
499508 self .agent_valid_data = filter_deep_dive_sql_result
500509 self .agent_data_analyse_result = agent_data_analyse_result
501510 self .answer .agent_search_result .agent_summary = agent_data_analyse_result
502- self .answer .agent_search_result .agent_sql_search_result = None
511+ self .answer .agent_search_result .agent_sql_search_result = agent_sql_search_result
503512 self .transition (QueryState .COMPLETE )
504513 except Exception as e :
505514 self .answer .error_log [QueryState .AGENT_DATA_SUMMARY .name ] = str (e )
@@ -539,7 +548,23 @@ def handle_data_visualization(self):
539548 self .get_answer ().sql_search_result .data_show_type = model_select_type
540549 self .get_answer ().sql_search_result .sql_data = show_select_data
541550 elif self .answer .query_intent == "agent_search" :
542- pass
551+ agent_sql_search_result = self .answer .agent_search_result .agent_sql_search_result
552+ agent_sql_search_result_with_visualization = []
553+ for each in agent_sql_search_result :
554+ model_select_type , show_select_data , select_chart_type , show_chart_data = data_visualization (
555+ self .context .model_type ,
556+ each .sub_task_query ,
557+ each .sql_search_result .sql_data ,
558+ self .context .database_profile ['prompt_map' ])
559+ if select_chart_type != "-1" :
560+ sql_chart_data = ChartEntity (chart_type = "" , chart_data = [])
561+ sql_chart_data .chart_type = select_chart_type
562+ sql_chart_data .chart_data = show_chart_data
563+ each .sql_search_result .sql_data_chart = [sql_chart_data ]
564+ each .sql_search_result .data_show_type = model_select_type
565+ each .sql_search_result .sql_data = show_select_data
566+ agent_sql_search_result_with_visualization .append (each )
567+ self .answer .agent_search_result = agent_sql_search_result_with_visualization
543568 except Exception as e :
544569 self .answer .error_log [QueryState .DATA_VISUALIZATION .name ] = str (e )
545570 logger .error (
0 commit comments