Skip to content

Commit 2916746

Browse files
authored
Merge pull request #170 from aws-samples/spy_dev
remove some code and replace space in SQL
2 parents 96ac580 + 2551c9d commit 2916746

File tree

2 files changed

+4
-27
lines changed

2 files changed

+4
-27
lines changed

application/api/main.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -65,28 +65,6 @@ async def websocket_endpoint(websocket: WebSocket):
6565
session_id = question.session_id
6666
ask_result = await ask_websocket(websocket, question)
6767
logger.info(ask_result)
68-
69-
70-
# current_nlq_chain = service.get_nlq_chain(question)
71-
# if question.use_rag:
72-
# examples = service.get_example(current_nlq_chain)
73-
# await response_websocket(websocket, session_id, "Examples:\n```json\n")
74-
# await response_websocket(websocket, session_id, str(examples))
75-
# await response_websocket(websocket, session_id, "\n```\n")
76-
# response = service.ask_with_response_stream(question, current_nlq_chain)
77-
# if os.getenv('SAGEMAKER_ENDPOINT_SQL', ''):
78-
# await response_sagemaker_sql(websocket, session_id, response, current_nlq_chain)
79-
# await response_websocket(websocket, session_id, "\n")
80-
# explain_response = service.explain_with_response_stream(current_nlq_chain)
81-
# await response_sagemaker_explain(websocket, session_id, explain_response)
82-
# else:
83-
# await response_bedrock(websocket, session_id, response, current_nlq_chain)
84-
#
85-
# if question.query_result:
86-
# final_sql_query_result = service.get_executed_result(current_nlq_chain)
87-
# await response_websocket(websocket, session_id, "\n\nQuery result: \n")
88-
# await response_websocket(websocket, session_id, final_sql_query_result)
89-
# await response_websocket(websocket, session_id, "\n")
9068
await response_websocket(websocket, session_id, ask_result.dict(), ContentEnum.END)
9169
except Exception:
9270
msg = traceback.format_exc()

application/api/service.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,13 @@ def ask(question: Question) -> Answer:
240240
split_strings = generated_sq.split("[generate]")
241241
generate_suggested_question_list = [s.strip() for s in split_strings if s.strip()]
242242

243-
# 连接数据库,执行SQL, 记录历史记录并展示
244243
if search_intent_flag:
245244
if normal_search_result.sql != "":
246245
current_nlq_chain.set_generated_sql(normal_search_result.sql)
247-
sql_search_result.sql = normal_search_result.sql
246+
sql_search_result.sql = normal_search_result.sql.strip()
248247
current_nlq_chain.set_generated_sql_response(normal_search_result.response)
249248
if explain_gen_process_flag:
250-
sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain()
249+
sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain().strip()
251250
else:
252251
sql_search_result.sql = "-1"
253252

@@ -484,10 +483,10 @@ async def ask_websocket(websocket: WebSocket, question: Question):
484483
if search_intent_flag:
485484
if normal_search_result.sql != "":
486485
current_nlq_chain.set_generated_sql(normal_search_result.sql)
487-
sql_search_result.sql = normal_search_result.sql
486+
sql_search_result.sql = normal_search_result.sql.strip()
488487
current_nlq_chain.set_generated_sql_response(normal_search_result.response)
489488
if explain_gen_process_flag:
490-
sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain()
489+
sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain().strip()
491490
else:
492491
sql_search_result.sql = "-1"
493492

0 commit comments

Comments
 (0)