4 changes: 3 additions & 1 deletion application/api/main.py
@@ -117,7 +117,9 @@ def user_feedback(input_data: FeedBackInput):
         return upvote_res
     else:
         downvote_res = service.user_feedback_downvote(input_data.data_profiles, user_id, session_id, input_data.query,
-                                                      input_data.query_intent, input_data.query_answer)
+                                                      input_data.query_intent, input_data.query_answer,
+                                                      input_data.error_description, input_data.error_categories,
+                                                      input_data.correct_sql_reference)
         return downvote_res
3 changes: 3 additions & 0 deletions application/api/schemas.py
@@ -59,6 +59,9 @@ class FeedBackInput(BaseModel):
     query_answer: str
     session_id: str = "-1"
     user_id: str = "admin"
+    error_description: str = ""
+    error_categories: str = ""
+    correct_sql_reference: str = ""


 class Option(BaseModel):
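For reference, a downvote submission now carries the structured feedback alongside the existing fields. A minimal sketch of such a payload, assuming `FeedBackInput` has no other required fields beyond those visible in this diff (all values illustrative):

```python
from api.schemas import FeedBackInput

# Hypothetical downvote payload exercising the three new optional fields
feedback = FeedBackInput(
    data_profiles="demo_profile",
    query="monthly revenue by region",
    query_intent="normal_search",
    query_answer="SELECT region, SUM(revenue) FROM sales GROUP BY region",
    session_id="42",
    user_id="alice",
    error_description="Result is missing the month grouping",
    error_categories="wrong_aggregation",
    correct_sql_reference="SELECT region, month, SUM(revenue) FROM sales GROUP BY region, month",
)
print(feedback.dict())  # .model_dump() on Pydantic v2
```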
19 changes: 14 additions & 5 deletions application/api/service.py
@@ -4,6 +4,7 @@

 from nlq.business.connection import ConnectionManagement
 from nlq.business.datasource.factory import DataSourceFactory
+from nlq.business.log_feedback import FeedBackManagement
 from nlq.business.model import ModelManagement
 from nlq.business.nlq_chain import NLQChain
 from nlq.business.profile import ProfileManagement
@@ -157,6 +158,7 @@ async def ask_websocket(websocket: WebSocket, question: Question):
     user_query_history = LogManagement.get_history_by_session(profile_name=selected_profile, user_id=user_id,
                                                               session_id=session_id, size=context_window,
                                                               log_type='chat_history')
+    user_query_history.append("user:" + search_box)

     if question.previous_intent == "entity_select":
         previous_state = QueryState.USER_SELECT_ENTITY.name
@@ -167,6 +169,7 @@ async def ask_websocket(websocket: WebSocket, question: Question):
                                           query_rewrite=question.query_rewrite,
                                           session_id=session_id,
                                           user_id=user_id,
+                                          username=username,
                                           selected_profile=selected_profile,
                                           database_profile=database_profile,
                                           model_type=model_type,
@@ -296,26 +299,32 @@ def user_feedback_upvote(data_profiles: str, user_id: str, session_id: str, quer


 def user_feedback_downvote(data_profiles: str, user_id: str, session_id: str, query: str, query_intent: str,
-                           query_answer):
+                           query_answer, error_description="", error_categories="", correct_sql_reference=""):
     try:
+        error_info_dict = {
+            "error_description": error_description,
+            "error_categories": error_categories,
+            "correct_sql_reference": correct_sql_reference
+        }
+        error_info = json.dumps(error_info_dict, ensure_ascii=False)
         if query_intent == "normal_search":
             log_id = generate_log_id()
             current_time = get_current_time()
-            LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id,
+            FeedBackManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id,
                                               profile_name=data_profiles,
                                               sql=query_answer, query=query,
                                               intent="normal_search_user_downvote",
-                                              log_info="",
+                                              log_info=error_info,
                                               time_str=current_time,
                                               log_type="feedback_downvote")
         elif query_intent == "agent_search":
             log_id = generate_log_id()
             current_time = get_current_time()
-            LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id,
+            FeedBackManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id,
                                               profile_name=data_profiles,
                                               sql=query_answer, query=query,
                                               intent="agent_search_user_downvote",
-                                              log_info="",
+                                              log_info=error_info,
                                               time_str=current_time,
                                               log_type="feedback_downvote")
         return True
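The three feedback fields are packed into a single JSON string and stored in the existing `log_info` column, so the log schema itself does not change. A small sketch of that round trip (values illustrative):

```python
import json

# ensure_ascii=False keeps non-ASCII feedback text readable in the stored log
error_info = json.dumps({
    "error_description": "Result is missing the month grouping",
    "error_categories": "wrong_aggregation",
    "correct_sql_reference": "SELECT region, month, SUM(revenue) FROM sales GROUP BY region, month",
}, ensure_ascii=False)

# A consumer of the feedback log recovers the fields with a single loads()
details = json.loads(error_info)
assert details["error_categories"] == "wrong_aggregation"
```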
5 changes: 3 additions & 2 deletions application/nlq/business/connection.py
@@ -1,8 +1,9 @@
-import logging

 from nlq.data_access.dynamo_connection import ConnectConfigDao, ConnectConfigEntity
+from nlq.data_access.database import RelationDatabase
+from utils.logging import getLogger

-logger = logging.getLogger(__name__)
+logger = getLogger()


 class ConnectionManagement:
18 changes: 18 additions & 0 deletions application/nlq/business/log_feedback.py
@@ -0,0 +1,18 @@
+import json
+import logging
+
+from nlq.data_access.dynamo_query_log import DynamoQueryLogDao
+from utils.logging import getLogger
+
+logger = getLogger()
+
+
+class FeedBackManagement:
+    dynammo_log_dao = DynamoQueryLogDao()
+
+    @classmethod
+    def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str,
+                            log_type='chat_history'):
+        cls.dynammo_log_dao.add_log(log_id=log_id, profile_name=profile_name, user_id=user_id, session_id=session_id,
+                                    sql=sql, query=query, intent=intent, log_info=log_info, time_str=time_str,
+                                    log_type=log_type)
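A usage sketch of the new facade, mirroring the downvote call sites in service.py; the log id and timestamp literals stand in for the `generate_log_id()` and `get_current_time()` helpers used there (all values illustrative):

```python
from nlq.business.log_feedback import FeedBackManagement

FeedBackManagement.add_log_to_database(
    log_id="log-0001",                    # normally generate_log_id()
    user_id="alice",
    session_id="42",
    profile_name="demo_profile",
    sql="SELECT region, SUM(revenue) FROM sales GROUP BY region",
    query="monthly revenue by region",
    intent="normal_search_user_downvote",
    log_info='{"error_description": "...", "error_categories": "...", "correct_sql_reference": "..."}',
    time_str="2024-01-01 00:00:00",       # normally get_current_time()
    log_type="feedback_downvote",
)
```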
4 changes: 2 additions & 2 deletions application/nlq/business/log_store.py
@@ -1,9 +1,9 @@
 import json
-import logging

 from nlq.data_access.opensearch_query_log import OpenSearchQueryLogDao
+from utils.logging import getLogger

-logger = logging.getLogger(__name__)
+logger = getLogger()


 class LogManagement:
5 changes: 3 additions & 2 deletions application/nlq/business/model.py
@@ -1,7 +1,8 @@
-import logging

 from nlq.data_access.dynamo_model import ModelConfigDao, ModelConfigEntity
+from utils.logging import getLogger

-logger = logging.getLogger(__name__)
+logger = getLogger()


 class ModelManagement:
5 changes: 3 additions & 2 deletions application/nlq/business/profile.py
@@ -1,7 +1,8 @@
-import logging

 from nlq.data_access.dynamo_profile import ProfileConfigDao, ProfileConfigEntity
+from utils.logging import getLogger

-logger = logging.getLogger(__name__)
+logger = getLogger()

 class ProfileManagement:
     profile_config_dao = ProfileConfigDao()
6 changes: 3 additions & 3 deletions application/nlq/business/vector_store.py
@@ -1,14 +1,14 @@
-import logging
 import os

 import boto3
 import json
 from nlq.data_access.opensearch import OpenSearchDao
 from utils.env_var import BEDROCK_REGION, AOS_HOST, AOS_PORT, AOS_USER, AOS_PASSWORD, opensearch_info, \
     SAGEMAKER_ENDPOINT_EMBEDDING
 from utils.env_var import bedrock_ak_sk_info
 from utils.llm import invoke_model_sagemaker_endpoint
+from utils.logging import getLogger

-logger = logging.getLogger(__name__)
+logger = getLogger()


 class VectorStore:
1 change: 1 addition & 0 deletions application/nlq/core/chat_context.py
@@ -7,6 +7,7 @@ class ProcessingContext:
     query_rewrite: str
     session_id: str
     user_id: str
+    username: str
     selected_profile: str
     database_profile: Dict[str, Any]
     model_type: str
57 changes: 41 additions & 16 deletions application/nlq/core/state_machine.py
@@ -4,7 +4,7 @@

 import pandas as pd

 from api.schemas import Answer, KnowledgeSearchResult, SQLSearchResult, AgentSearchResult, AskReplayResult, \
-    AskEntitySelect, ChartEntity
+    AskEntitySelect, ChartEntity, TaskSQLSearchResult
 from nlq.business.datasource.factory import DataSourceFactory
 from nlq.business.log_store import LogManagement
 from nlq.core.chat_context import ProcessingContext
@@ -244,8 +244,8 @@ def handle_sql_generation(self):
        self.intent_search_result["sql"] = sql
        self.intent_search_result["response"] = response
        self.intent_search_result["original_sql"] = original_sql
-       self.answer.sql_search_result.sql = sql
-       self.answer.sql_search_result.sql_gen_process = get_generated_sql(response)
+       self.answer.sql_search_result.sql = sql.strip()
+       self.answer.sql_search_result.sql_gen_process = get_generated_sql_explain(response).strip()
        if self.context.visualize_results_flag:
            self.transition(QueryState.EXECUTE_QUERY)
        else:
@@ -256,7 +256,7 @@ def _apply_row_level_security_for_sql(self, sql):
            self.context.database_profile['db_type'],
            sql,
            self.context.database_profile['row_level_security_config'],
-           self.context.user_id
+           self.context.username
        )
        return post_sql
@@ -457,22 +457,28 @@ def handle_analyze_data(self):

    @log_execution
    def handle_agent_task(self):
-
-       self.agent_cot_retrieve = get_retrieve_opensearch(self.context.opensearch_info, self.context.query_rewrite,
-                                                         "agent", self.context.selected_profile, 2, 0.5)
-
-       agent_cot_task_result = get_agent_cot_task(self.context.model_type, self.context.database_profile["prompt_map"],
-                                                  self.context.query_rewrite,
-                                                  self.context.database_profile['tables_info'],
-                                                  self.agent_cot_retrieve)
-       self.agent_task_split = agent_cot_task_result
-       self.transition(QueryState.AGENT_SEARCH)
+       # Analyze the task
+       try:
+           self.agent_cot_retrieve = get_retrieve_opensearch(self.context.opensearch_info, self.context.query_rewrite,
+                                                             "agent", self.context.selected_profile, 2, 0.5)
+
+           agent_cot_task_result = get_agent_cot_task(self.context.model_type, self.context.database_profile["prompt_map"],
+                                                      self.context.query_rewrite,
+                                                      self.context.database_profile['tables_info'],
+                                                      self.agent_cot_retrieve)
+           self.agent_task_split = agent_cot_task_result
+           self.transition(QueryState.AGENT_SEARCH)
+       except Exception as e:
+           self.answer.error_log[QueryState.AGENT_TASK.name] = str(e)
+           logger.error(f"The context is {self.context.search_box}, handle_agent_task encountered an error: {e}")
+           self.transition(QueryState.ERROR)

    @log_execution
    def handle_agent_analyze_data(self):
        # Analyze the data
        try:
            filter_deep_dive_sql_result = []
+           agent_sql_search_result = []
            for i in range(len(self.agent_search_result)):
                each_task_res = get_sql_result_tool(
                    self.context.database_profile,
@@ -489,6 +495,9 @@ def handle_agent_analyze_data(self):
                                                       data_show_type="table",
                                                       sql_gen_process=each_task_sql_response,
                                                       data_analyse="", sql_data_chart=[])
+               each_task_sql_search_result = TaskSQLSearchResult(sub_task_query=self.agent_search_result[i]["query"],
+                                                                 sql_search_result=sub_task_sql_result)
+               agent_sql_search_result.append(each_task_sql_search_result)

            agent_data_analyse_result = data_analyse_tool(self.context.model_type,
                                                          self.context.database_profile["prompt_map"],
@@ ... @@
            self.agent_valid_data = filter_deep_dive_sql_result
            self.agent_data_analyse_result = agent_data_analyse_result
            self.answer.agent_search_result.agent_summary = agent_data_analyse_result
-           self.answer.agent_search_result.agent_sql_search_result = None
+           self.answer.agent_search_result.agent_sql_search_result = agent_sql_search_result
            self.transition(QueryState.COMPLETE)
        except Exception as e:
            self.answer.error_log[QueryState.AGENT_DATA_SUMMARY.name] = str(e)
@@ -539,7 +548,23 @@ def handle_data_visualization(self):
                self.get_answer().sql_search_result.data_show_type = model_select_type
                self.get_answer().sql_search_result.sql_data = show_select_data
            elif self.answer.query_intent == "agent_search":
-               pass
+               agent_sql_search_result = self.answer.agent_search_result.agent_sql_search_result
+               agent_sql_search_result_with_visualization = []
+               for each in agent_sql_search_result:
+                   model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(
+                       self.context.model_type,
+                       each.sub_task_query,
+                       each.sql_search_result.sql_data,
+                       self.context.database_profile['prompt_map'])
+                   if select_chart_type != "-1":
+                       sql_chart_data = ChartEntity(chart_type="", chart_data=[])
+                       sql_chart_data.chart_type = select_chart_type
+                       sql_chart_data.chart_data = show_chart_data
+                       each.sql_search_result.sql_data_chart = [sql_chart_data]
+                   each.sql_search_result.data_show_type = model_select_type
+                   each.sql_search_result.sql_data = show_select_data
+                   agent_sql_search_result_with_visualization.append(each)
+               # assign back to the nested field, not the whole AgentSearchResult
+               self.answer.agent_search_result.agent_sql_search_result = agent_sql_search_result_with_visualization
        except Exception as e:
            self.answer.error_log[QueryState.DATA_VISUALIZATION.name] = str(e)
            logger.error(
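The hunks above construct `TaskSQLSearchResult` from api.schemas; its definition is not part of this diff, but from the keyword arguments used it plausibly pairs each sub-task's question with its SQL result. A hypothetical reconstruction, assuming a Pydantic model like its siblings in api/schemas.py:

```python
from pydantic import BaseModel
from api.schemas import SQLSearchResult  # existing model referenced in the imports above

# Hypothetical sketch -- the real definition lives in application/api/schemas.py
class TaskSQLSearchResult(BaseModel):
    sub_task_query: str
    sql_search_result: SQLSearchResult
```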
35 changes: 33 additions & 2 deletions application/nlq/data_access/database.py
@@ -121,6 +121,10 @@ def get_table_definition_by_connection(cls, connection: ConnectConfigEntity, sch
        metadata = cls.get_metadata_by_connection(connection, schemas)
        tables = metadata.tables
        table_info = {}
+       if connection.db_type == 'hive':
+           tables_comment = cls.get_hive_table_comment(connection, table_names)
+       else:
+           tables_comment = {}

        for table_name, table in tables.items():
            # If table name is provided, only generate DDL for those tables. Otherwise, generate DDL for all tables.
@@ -129,20 +133,47 @@
            # Start the DDL statement
            table_comment = f'-- {table.comment}' if table.comment else ''
            ddl = f"CREATE TABLE {table_name} {table_comment} \n (\n"

+           if table_name in tables_comment:
+               column_comment_value = tables_comment[table_name]
+           else:
+               column_comment_value = {}
            for column in table.columns:
                column: Column
                # get column description
-               column_comment = f'-- {column.comment}' if column.comment else ''
+               if column.comment is None:
+                   if column.name in column_comment_value:
+                       column.comment = column_comment_value[column.name]
+               column_comment = f'COMMENT {column.comment}' if column.comment else ''
                ddl += f" {column.name} {column.type.__visit_name__} {column_comment},\n"
            ddl = ddl.rstrip(',\n') + "\n)"  # Remove the last comma and close the CREATE TABLE statement
            table_info[table_name] = {}
            table_info[table_name]['ddl'] = ddl
            table_info[table_name]['description'] = table.comment

            logger.info(f'added table {table_name} to table_info dict')

        return table_info

+   @classmethod
+   def get_hive_table_comment(cls, connection, table_names):
+       table_name_comment = {}
+       try:
+           db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
+                                   connection.db_port, connection.db_name)
+           engine = db.create_engine(db_url)
+           for each_table in table_names:
+               table_name_comment[each_table] = {}
+               with engine.connect() as conn:  # renamed so it does not shadow the `connection` config entity
+                   sql = "describe " + each_table
+                   result = conn.execute(sql)
+                   for row in result:
+                       if len(row) == 3 and row[2]:  # skip columns with no comment
+                           table_name_comment[each_table][row[0]] = "'" + row[2] + "'"
+           return table_name_comment
+       except Exception as e:
+           logger.error(f"Failed to get table comment: {str(e)}")
+           return table_name_comment
+
    @classmethod
    def get_db_url_by_connection(cls, connection: ConnectConfigEntity):
        db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
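`get_hive_table_comment()` assumes Hive's `DESCRIBE <table>` output of (col_name, data_type, comment) rows. A standalone sketch of the mapping it builds, with made-up rows:

```python
# Made-up DESCRIBE output in Hive's (col_name, data_type, comment) shape
describe_rows = [
    ("region", "string", "sales region code"),
    ("revenue", "double", "gross revenue in USD"),
]

# Same mapping the method builds: column name -> quoted comment,
# ready to be spliced into the generated DDL as COMMENT '<text>'
table_name_comment = {"sales": {}}
for row in describe_rows:
    if len(row) == 3 and row[2]:
        table_name_comment["sales"][row[0]] = "'" + row[2] + "'"

print(table_name_comment)
# {'sales': {'region': "'sales region code'", 'revenue': "'gross revenue in USD'"}}
```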
6 changes: 4 additions & 2 deletions application/pages/1_🌍_Generative_BI_Playground.py
@@ -149,7 +149,8 @@ def main():
    # Title and Description
    st.subheader('Generative BI Playground')

-   demo_profile_suffix = '(demo)'
+   st.write('Current Username: ' + st.session_state['auth_username'])

    # Initialize or set up state variables

    if "update_profile" not in st.session_state:
@@ -340,7 +341,8 @@ def main():
                search_box=search_box,
                query_rewrite="",
                session_id="",
-               user_id="",
+               user_id=st.session_state['auth_username'],
+               username=st.session_state['auth_username'],
                selected_profile=selected_profile,
                database_profile=database_profile,
                model_type=model_type,
1 change: 1 addition & 0 deletions application/utils/llm.py
@@ -225,6 +225,7 @@ def invoke_llm_model(model_id, system_prompt, user_prompt, max_tokens=2048, with
        logger.info(f'{body=}')
        endpoint_name = model_id[len('sagemaker.'):]
        response = invoke_model_sagemaker_endpoint(endpoint_name, body, "LLM", with_response_stream)
+       logger.info(f'{response=}')
        if with_response_stream:
            return response
        else:
2 changes: 2 additions & 0 deletions application/utils/logging.py
@@ -27,6 +27,8 @@ def getLogger():
    # Set the formatter on the log handler
    console_handler.setFormatter(formatter)

+   # Clear any previously attached log handlers
+   logger.handlers.clear()
    # Add the log handler
    logger.addHandler(console_handler)
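The `handlers.clear()` matters because `getLogger()` is called from many modules and `logging.getLogger()` returns the same logger object each time; without the clear, every call stacks one more handler and each record prints repeatedly. A minimal standalone repro of the idea:

```python
import logging

def get_logger():
    logger = logging.getLogger("app")          # same object on every call
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
    logger.handlers.clear()                    # without this, handlers accumulate
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger

get_logger().info("printed once")
get_logger().info("still printed once")        # would print twice without clear()
```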