Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions application/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,28 +65,6 @@ async def websocket_endpoint(websocket: WebSocket):
session_id = question.session_id
ask_result = await ask_websocket(websocket, question)
logger.info(ask_result)


# current_nlq_chain = service.get_nlq_chain(question)
# if question.use_rag:
# examples = service.get_example(current_nlq_chain)
# await response_websocket(websocket, session_id, "Examples:\n```json\n")
# await response_websocket(websocket, session_id, str(examples))
# await response_websocket(websocket, session_id, "\n```\n")
# response = service.ask_with_response_stream(question, current_nlq_chain)
# if os.getenv('SAGEMAKER_ENDPOINT_SQL', ''):
# await response_sagemaker_sql(websocket, session_id, response, current_nlq_chain)
# await response_websocket(websocket, session_id, "\n")
# explain_response = service.explain_with_response_stream(current_nlq_chain)
# await response_sagemaker_explain(websocket, session_id, explain_response)
# else:
# await response_bedrock(websocket, session_id, response, current_nlq_chain)
#
# if question.query_result:
# final_sql_query_result = service.get_executed_result(current_nlq_chain)
# await response_websocket(websocket, session_id, "\n\nQuery result: \n")
# await response_websocket(websocket, session_id, final_sql_query_result)
# await response_websocket(websocket, session_id, "\n")
await response_websocket(websocket, session_id, ask_result.dict(), ContentEnum.END)
except Exception:
msg = traceback.format_exc()
Expand Down
9 changes: 4 additions & 5 deletions application/api/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,13 @@ def ask(question: Question) -> Answer:
split_strings = generated_sq.split("[generate]")
generate_suggested_question_list = [s.strip() for s in split_strings if s.strip()]

# 连接数据库,执行SQL, 记录历史记录并展示
if search_intent_flag:
if normal_search_result.sql != "":
current_nlq_chain.set_generated_sql(normal_search_result.sql)
sql_search_result.sql = normal_search_result.sql
sql_search_result.sql = normal_search_result.sql.strip()
current_nlq_chain.set_generated_sql_response(normal_search_result.response)
if explain_gen_process_flag:
sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain()
sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain().strip()
else:
sql_search_result.sql = "-1"

Expand Down Expand Up @@ -484,10 +483,10 @@ async def ask_websocket(websocket: WebSocket, question: Question):
if search_intent_flag:
if normal_search_result.sql != "":
current_nlq_chain.set_generated_sql(normal_search_result.sql)
sql_search_result.sql = normal_search_result.sql
sql_search_result.sql = normal_search_result.sql.strip()
current_nlq_chain.set_generated_sql_response(normal_search_result.response)
if explain_gen_process_flag:
sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain()
sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain().strip()
else:
sql_search_result.sql = "-1"

Expand Down