9 changes: 8 additions & 1 deletion application/api/main.py
@@ -6,7 +6,7 @@

from nlq.business.profile import ProfileManagement
from .enum import ContentEnum, ErrorEnum
from .schemas import Question, QuestionSocket, Answer, Option, CustomQuestion
from .schemas import Question, QuestionSocket, Answer, Option, CustomQuestion, Upvote
from . import service
from nlq.business.nlq_chain import NLQChain
from dotenv import load_dotenv
@@ -38,6 +38,13 @@ def ask(question: Question):
return service.ask(question)


@router.post("/upvote")
def upvote(upvote_input: Upvote):
upvote_res = service.user_feedback_upvote(upvote_input.data_profiles,upvote_input.query,
upvote_input.query_intent, upvote_input.query_answer_list)
return upvote_res


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
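For reference, a hypothetical client call to the new upvote route might look like the sketch below. Only the /upvote path and the Upvote payload shape come from this diff; the base URL, any router prefix, and the example values are assumptions.

import requests

# Assumed host/port; the real URL depends on how the FastAPI app and router are mounted.
BASE_URL = "http://localhost:8000"

payload = {
    "data_profiles": "demo_profile",              # example profile name
    "query": "Total sales by region in 2023",
    "query_intent": "normal_search",              # or "agent_search"
    "query_answer_list": [
        {
            "query": "Total sales by region in 2023",
            "sql": "SELECT region, SUM(sales) FROM orders GROUP BY region",
        }
    ],
}

resp = requests.post(f"{BASE_URL}/upvote", json=payload)
print(resp.json())  # the endpoint returns the boolean result of user_feedback_upvote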
13 changes: 13 additions & 0 deletions application/api/schemas.py
@@ -35,6 +35,18 @@ class Example(BaseModel):
# sql_query_result: list[Any]


class QueryEntity(BaseModel):
query: str
sql: str


class Upvote(BaseModel):
data_profiles: str
query: str
query_intent: str
query_answer_list: list[QueryEntity]


class Option(BaseModel):
data_profiles: list[str]
bedrock_model_ids: list[str]
@@ -57,6 +69,7 @@ class KnowledgeSearchResult(BaseModel):


class AgentSearchResult(BaseModel):
sub_search_task: list[str]
agent_sql_search_result: list[SQLSearchResult]
agent_summary: str

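With the new sub_search_task field, an agent-search Answer serializes roughly as in the sketch below. This is only an illustration of the shape implied by the models above; the values are made up and unrelated Answer fields are omitted.

# Sketch of the agent-search portion of an Answer payload (illustrative values only).
agent_answer_fragment = {
    "query_intent": "agent_search",
    "agent_search_result": {
        "sub_search_task": ["Total sales in 2022", "Total sales in 2023"],
        "agent_sql_search_result": [],  # one SQLSearchResult per successful sub-task
        "agent_summary": "2023 sales grew versus 2022, driven mainly by one region.",
    },
}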
44 changes: 35 additions & 9 deletions application/api/service.py
@@ -7,6 +7,7 @@
from nlq.business.connection import ConnectionManagement
from nlq.business.nlq_chain import NLQChain
from nlq.business.profile import ProfileManagement
from nlq.business.vector_store import VectorStore
from utils.apis import get_sql_result_tool
from utils.database import get_db_url_dialect
from utils.llm import text_to_sql, get_query_intent, create_vector_embedding_with_sagemaker, \
@@ -182,7 +183,7 @@ def ask(question: Question) -> Answer:
sql_gen_process="",
data_analyse="")

agent_search_response = AgentSearchResult(agent_summary="", agent_sql_search_result=[])
agent_search_response = AgentSearchResult(agent_summary="", agent_sql_search_result=[], sub_search_task=[])

knowledge_search_result = KnowledgeSearchResult(knowledge_response="")

@@ -221,7 +222,8 @@ def ask(question: Question) -> Answer:

if reject_intent_flag:
answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result,
sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[])
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
suggested_question=[])
return answer
elif search_intent_flag:
normal_search_result = normal_text_search(search_box, model_type,
@@ -232,8 +234,10 @@ def ask(question: Question) -> Answer:
response = knowledge_search(search_box=search_box, model_id=model_type)

knowledge_search_result.knowledge_response = response
answer = Answer(query=search_box, query_intent="knowledge_search", knowledge_search_result=knowledge_search_result,
sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[])
answer = Answer(query=search_box, query_intent="knowledge_search",
knowledge_search_result=knowledge_search_result,
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
suggested_question=[])
return answer

else:
@@ -269,38 +273,60 @@ def ask(question: Question) -> Answer:
search_intent_result["data"].to_json(
orient='records', force_ascii=False), "query")
sql_search_result.data_analyse = search_intent_analyse_result
sql_search_result.sql_data = [list(search_intent_result["data"].columns)] +search_intent_result["data"].values.tolist()
sql_search_result.sql_data = [list(search_intent_result["data"].columns)] + search_intent_result[
"data"].values.tolist()

answer = Answer(query=search_box, query_intent="normal_search", knowledge_search_result=knowledge_search_result,
sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[])
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
suggested_question=[])
return answer
else:
sub_search_task = []
for i in range(len(agent_search_result)):
each_task_res = get_sql_result_tool(database_profile, agent_search_result[i]["sql"])
if each_task_res["status_code"] == 200 and len(each_task_res["data"]) > 0:
agent_search_result[i]["data_result"] = each_task_res["data"].to_json(
orient='records')
filter_deep_dive_sql_result.append(agent_search_result[i])
each_task_sql_res = [list(each_task_res["data"].columns)] + each_task_res["data"].values.tolist()
each_task_sql_search_result = SQLSearchResult(query=agent_search_result[i]["query"], sql_data=each_task_sql_res,
each_task_sql_search_result = SQLSearchResult(query=agent_search_result[i]["query"],
sql_data=each_task_sql_res,
sql=each_task_res["sql"], data_show_type="table",
sql_gen_process="",
data_analyse="")
agent_sql_search_result.append(each_task_sql_search_result)

sub_search_task.append(agent_search_result[i]["query"])
agent_data_analyse_result = data_analyse_tool(model_type, search_box,
json.dumps(filter_deep_dive_sql_result, ensure_ascii=False),
"agent")
logger.info("agent_data_analyse_result")
logger.info(agent_data_analyse_result)
agent_search_response.agent_summary = agent_data_analyse_result
agent_search_response.agent_sql_search_result = agent_sql_search_result
agent_search_response.sub_search_task = sub_search_task

answer = Answer(query=search_box, query_intent="agent_search", knowledge_search_result=knowledge_search_result,
sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[])
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
suggested_question=[])
return answer


def user_feedback_upvote(data_profiles: str, query: str, query_intent: str, query_answer_list):
try:
if query_intent == "normal_search":
if len(query_answer_list) > 0:
VectorStore.add_sample(data_profiles, query_answer_list[0].query, query_answer_list[0].sql)
elif query_intent == "agent_search":
query_list = []
for each in query_answer_list:
query_list.append(each.query)
VectorStore.add_sample(data_profiles, each.query, each.sql)
VectorStore.add_agent_cot_sample(data_profiles, query, "\n".join(query_list))
return True
except Exception as e:
return False


def get_nlq_chain(question: Question) -> NLQChain:
logger.debug(question)
verify_parameters(question)
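In the new user_feedback_upvote above, a "normal_search" upvote stores only the first query/SQL pair via VectorStore.add_sample, while an "agent_search" upvote stores every sub-query's SQL and also records the task decomposition with VectorStore.add_agent_cot_sample. A hedged sketch of calling it directly follows; the import paths and values are assumptions that depend on the package layout.

from api.schemas import QueryEntity  # assumed import path for the models above
from api import service

ok = service.user_feedback_upvote(
    data_profiles="demo_profile",
    query="Compare 2022 vs 2023 sales and explain the gap",
    query_intent="agent_search",
    query_answer_list=[
        QueryEntity(query="Total sales in 2022",
                    sql="SELECT SUM(sales) FROM orders WHERE year = 2022"),
        QueryEntity(query="Total sales in 2023",
                    sql="SELECT SUM(sales) FROM orders WHERE year = 2023"),
    ],
)
# ok is True when the vector-store writes succeed; any exception is caught and False is returned.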
20 changes: 8 additions & 12 deletions application/pages/1_🌍_Generative_BI_Playground.py
@@ -62,8 +62,7 @@ def do_visualize_results(nlq_chain, sql_result):
# hacky way to get around the issue of selectbox not updating when the options change
chart_type = visualize_config_columns[0].selectbox('Choose the chart type',
['Table', 'Bar', 'Line', 'Pie'],
on_change=nlq_chain.set_visualization_config_change,
key=random.randint(0, 10000)
on_change=nlq_chain.set_visualization_config_change
)
if chart_type != 'Table':
x_column = visualize_config_columns[1].selectbox(f'Choose x-axis column', available_columns,
@@ -93,7 +92,7 @@ def recurrent_display(messages, i, current_nlq_chain):
message = messages[i]
if message["type"] == "pandas":
if isinstance(message["content"], pd.DataFrame):
do_visualize_results(current_nlq_chain, message["content"])
st.dataframe(message["content"], hide_index=True)
elif isinstance(message["content"], list):
for each_content in message["content"]:
st.write(each_content["query"])
@@ -102,13 +101,10 @@ def recurrent_display(messages, i, current_nlq_chain):
st.markdown(message["content"])
elif message["type"] == "error":
st.error(message["content"])
if i + 1 < len(messages):
if current_role != messages[i + 1]["role"]:
return i
else:
return recurrent_display(messages, i + 1, current_nlq_chain)
else:
return i
elif message["type"] == "sql":
with st.expander("The Generate SQL"):
st.code(message["content"], language="sql")
return i


def main():
@@ -221,8 +217,8 @@ def main():
for i in range(len(st.session_state.messages[selected_profile])):
print('!!!!!')
print(i, new_index)
if i - 1 < new_index:
continue
# if i - 1 < new_index:
# continue
with st.chat_message(st.session_state.messages[selected_profile][i]["role"]):
new_index = recurrent_display(st.session_state.messages[selected_profile], i, current_nlq_chain)

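The reworked recurrent_display now renders DataFrames directly with st.dataframe, shows generated SQL inside an expander, and simply returns the current index instead of recursing, while main() no longer skips messages before new_index. For orientation, the message dicts it handles have roughly the shape sketched below; this is inferred from the branches above, and any field beyond role/type/content is an assumption.

import pandas as pd

df = pd.DataFrame({"region": ["APAC", "EMEA"], "sales": [120, 95]})  # illustrative data

messages = [
    {"role": "user", "type": "text", "content": "Total sales by region"},
    {"role": "assistant", "type": "sql",
     "content": "SELECT region, SUM(sales) FROM orders GROUP BY region"},
    {"role": "assistant", "type": "pandas", "content": df},          # rendered via st.dataframe
    {"role": "assistant", "type": "error", "content": "Query timed out"},
]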