Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 38 additions & 13 deletions application/nlq/core/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd

from api.schemas import Answer, KnowledgeSearchResult, SQLSearchResult, AgentSearchResult, AskReplayResult, \
AskEntitySelect, ChartEntity
AskEntitySelect, ChartEntity, TaskSQLSearchResult
from nlq.business.datasource.factory import DataSourceFactory
from nlq.business.log_store import LogManagement
from nlq.core.chat_context import ProcessingContext
Expand Down Expand Up @@ -457,22 +457,28 @@ def handle_analyze_data(self):

@log_execution
def handle_agent_task(self):

self.agent_cot_retrieve = get_retrieve_opensearch(self.context.opensearch_info, self.context.query_rewrite,
"agent", self.context.selected_profile, 2, 0.5)

agent_cot_task_result = get_agent_cot_task(self.context.model_type, self.context.database_profile["prompt_map"],
self.context.query_rewrite,
self.context.database_profile['tables_info'],
self.agent_cot_retrieve)
self.agent_task_split = agent_cot_task_result
self.transition(QueryState.AGENT_SEARCH)
# Analyze the task
try:
self.agent_cot_retrieve = get_retrieve_opensearch(self.context.opensearch_info, self.context.query_rewrite,
"agent", self.context.selected_profile, 2, 0.5)

agent_cot_task_result = get_agent_cot_task(self.context.model_type, self.context.database_profile["prompt_map"],
self.context.query_rewrite,
self.context.database_profile['tables_info'],
self.agent_cot_retrieve)
self.agent_task_split = agent_cot_task_result
self.transition(QueryState.AGENT_SEARCH)
except Exception as e:
self.answer.error_log[QueryState.AGENT_TASK.name] = str(e)
logger.error(f"The context is {self.context.search_box}, handle_agent_task encountered an error: {e}")
self.transition(QueryState.ERROR)

@log_execution
def handle_agent_analyze_data(self):
# Analyze the data
try:
filter_deep_dive_sql_result = []
agent_sql_search_result = []
for i in range(len(self.agent_search_result)):
each_task_res = get_sql_result_tool(
self.context.database_profile,
Expand All @@ -489,6 +495,9 @@ def handle_agent_analyze_data(self):
data_show_type="table",
sql_gen_process=each_task_sql_response,
data_analyse="", sql_data_chart=[])
each_task_sql_search_result = TaskSQLSearchResult(sub_task_query=self.agent_search_result[i]["query"],
sql_search_result=sub_task_sql_result)
agent_sql_search_result.append(each_task_sql_search_result)

agent_data_analyse_result = data_analyse_tool(self.context.model_type,
self.context.database_profile["prompt_map"],
Expand All @@ -499,7 +508,7 @@ def handle_agent_analyze_data(self):
self.agent_valid_data = filter_deep_dive_sql_result
self.agent_data_analyse_result = agent_data_analyse_result
self.answer.agent_search_result.agent_summary = agent_data_analyse_result
self.answer.agent_search_result.agent_sql_search_result = None
self.answer.agent_search_result.agent_sql_search_result = agent_sql_search_result
self.transition(QueryState.COMPLETE)
except Exception as e:
self.answer.error_log[QueryState.AGENT_DATA_SUMMARY.name] = str(e)
Expand Down Expand Up @@ -539,7 +548,23 @@ def handle_data_visualization(self):
self.get_answer().sql_search_result.data_show_type = model_select_type
self.get_answer().sql_search_result.sql_data = show_select_data
elif self.answer.query_intent == "agent_search":
pass
agent_sql_search_result = self.answer.agent_search_result.agent_sql_search_result
agent_sql_search_result_with_visualization = []
for each in agent_sql_search_result:
model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(
self.context.model_type,
each.sub_task_query,
each.sql_search_result.sql_data,
self.context.database_profile['prompt_map'])
if select_chart_type != "-1":
sql_chart_data = ChartEntity(chart_type="", chart_data=[])
sql_chart_data.chart_type = select_chart_type
sql_chart_data.chart_data = show_chart_data
each.sql_search_result.sql_data_chart = [sql_chart_data]
each.sql_search_result.data_show_type = model_select_type
each.sql_search_result.sql_data = show_select_data
agent_sql_search_result_with_visualization.append(each)
self.answer.agent_search_result = agent_sql_search_result_with_visualization
except Exception as e:
self.answer.error_log[QueryState.DATA_VISUALIZATION.name] = str(e)
logger.error(
Expand Down