Skip to content

Commit 493b45c

Browse files
authored
Merge pull request #304 from aws-samples/v1.7.0_dev
add data visualization
2 parents 5491f6f + ba7ddf2 commit 493b45c

File tree

1 file changed

+38
-13
lines changed

1 file changed

+38
-13
lines changed

application/nlq/core/state_machine.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55

66
from api.schemas import Answer, KnowledgeSearchResult, SQLSearchResult, AgentSearchResult, AskReplayResult, \
7-
AskEntitySelect, ChartEntity
7+
AskEntitySelect, ChartEntity, TaskSQLSearchResult
88
from nlq.business.datasource.factory import DataSourceFactory
99
from nlq.business.log_store import LogManagement
1010
from nlq.core.chat_context import ProcessingContext
@@ -457,22 +457,28 @@ def handle_analyze_data(self):
457457

458458
@log_execution
459459
def handle_agent_task(self):
460-
461-
self.agent_cot_retrieve = get_retrieve_opensearch(self.context.opensearch_info, self.context.query_rewrite,
462-
"agent", self.context.selected_profile, 2, 0.5)
463-
464-
agent_cot_task_result = get_agent_cot_task(self.context.model_type, self.context.database_profile["prompt_map"],
465-
self.context.query_rewrite,
466-
self.context.database_profile['tables_info'],
467-
self.agent_cot_retrieve)
468-
self.agent_task_split = agent_cot_task_result
469-
self.transition(QueryState.AGENT_SEARCH)
460+
# Analyze the task
461+
try:
462+
self.agent_cot_retrieve = get_retrieve_opensearch(self.context.opensearch_info, self.context.query_rewrite,
463+
"agent", self.context.selected_profile, 2, 0.5)
464+
465+
agent_cot_task_result = get_agent_cot_task(self.context.model_type, self.context.database_profile["prompt_map"],
466+
self.context.query_rewrite,
467+
self.context.database_profile['tables_info'],
468+
self.agent_cot_retrieve)
469+
self.agent_task_split = agent_cot_task_result
470+
self.transition(QueryState.AGENT_SEARCH)
471+
except Exception as e:
472+
self.answer.error_log[QueryState.AGENT_TASK.name] = str(e)
473+
logger.error(f"The context is {self.context.search_box}, handle_agent_task encountered an error: {e}")
474+
self.transition(QueryState.ERROR)
470475

471476
@log_execution
472477
def handle_agent_analyze_data(self):
473478
# Analyze the data
474479
try:
475480
filter_deep_dive_sql_result = []
481+
agent_sql_search_result = []
476482
for i in range(len(self.agent_search_result)):
477483
each_task_res = get_sql_result_tool(
478484
self.context.database_profile,
@@ -489,6 +495,9 @@ def handle_agent_analyze_data(self):
489495
data_show_type="table",
490496
sql_gen_process=each_task_sql_response,
491497
data_analyse="", sql_data_chart=[])
498+
each_task_sql_search_result = TaskSQLSearchResult(sub_task_query=self.agent_search_result[i]["query"],
499+
sql_search_result=sub_task_sql_result)
500+
agent_sql_search_result.append(each_task_sql_search_result)
492501

493502
agent_data_analyse_result = data_analyse_tool(self.context.model_type,
494503
self.context.database_profile["prompt_map"],
@@ -499,7 +508,7 @@ def handle_agent_analyze_data(self):
499508
self.agent_valid_data = filter_deep_dive_sql_result
500509
self.agent_data_analyse_result = agent_data_analyse_result
501510
self.answer.agent_search_result.agent_summary = agent_data_analyse_result
502-
self.answer.agent_search_result.agent_sql_search_result = None
511+
self.answer.agent_search_result.agent_sql_search_result = agent_sql_search_result
503512
self.transition(QueryState.COMPLETE)
504513
except Exception as e:
505514
self.answer.error_log[QueryState.AGENT_DATA_SUMMARY.name] = str(e)
@@ -539,7 +548,23 @@ def handle_data_visualization(self):
539548
self.get_answer().sql_search_result.data_show_type = model_select_type
540549
self.get_answer().sql_search_result.sql_data = show_select_data
541550
elif self.answer.query_intent == "agent_search":
542-
pass
551+
agent_sql_search_result = self.answer.agent_search_result.agent_sql_search_result
552+
agent_sql_search_result_with_visualization = []
553+
for each in agent_sql_search_result:
554+
model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(
555+
self.context.model_type,
556+
each.sub_task_query,
557+
each.sql_search_result.sql_data,
558+
self.context.database_profile['prompt_map'])
559+
if select_chart_type != "-1":
560+
sql_chart_data = ChartEntity(chart_type="", chart_data=[])
561+
sql_chart_data.chart_type = select_chart_type
562+
sql_chart_data.chart_data = show_chart_data
563+
each.sql_search_result.sql_data_chart = [sql_chart_data]
564+
each.sql_search_result.data_show_type = model_select_type
565+
each.sql_search_result.sql_data = show_select_data
566+
agent_sql_search_result_with_visualization.append(each)
567+
self.answer.agent_search_result = agent_sql_search_result_with_visualization
543568
except Exception as e:
544569
self.answer.error_log[QueryState.DATA_VISUALIZATION.name] = str(e)
545570
logger.error(

0 commit comments

Comments
 (0)