Skip to content

Commit 583e77b

Browse files
authored
Merge pull request #59 from aws-samples/spy_dev
add user_feedback_upvote
2 parents 92268fd + 2d1e34d commit 583e77b

File tree

4 files changed

+64
-22
lines changed

4 files changed

+64
-22
lines changed

application/api/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from nlq.business.profile import ProfileManagement
88
from .enum import ContentEnum, ErrorEnum
9-
from .schemas import Question, QuestionSocket, Answer, Option, CustomQuestion
9+
from .schemas import Question, QuestionSocket, Answer, Option, CustomQuestion, Upvote
1010
from . import service
1111
from nlq.business.nlq_chain import NLQChain
1212
from dotenv import load_dotenv
@@ -38,6 +38,13 @@ def ask(question: Question):
3838
return service.ask(question)
3939

4040

41+
@router.post("/upvote")
42+
def upvote(upvote_input: Upvote):
43+
upvote_res = service.user_feedback_upvote(upvote_input.data_profiles,upvote_input.query,
44+
upvote_input.query_intent, upvote_input.query_answer_list)
45+
return upvote_res
46+
47+
4148
@router.websocket("/ws")
4249
async def websocket_endpoint(websocket: WebSocket):
4350
await websocket.accept()

application/api/schemas.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ class Example(BaseModel):
3535
# sql_query_result: list[Any]
3636

3737

38+
class QueryEntity(BaseModel):
39+
query: str
40+
sql: str
41+
42+
43+
class Upvote(BaseModel):
44+
data_profiles: str
45+
query: str
46+
query_intent: str
47+
query_answer_list: list[QueryEntity]
48+
49+
3850
class Option(BaseModel):
3951
data_profiles: list[str]
4052
bedrock_model_ids: list[str]
@@ -57,6 +69,7 @@ class KnowledgeSearchResult(BaseModel):
5769

5870

5971
class AgentSearchResult(BaseModel):
72+
sub_search_task: list[str]
6073
agent_sql_search_result: list[SQLSearchResult]
6174
agent_summary: str
6275

application/api/service.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from nlq.business.connection import ConnectionManagement
88
from nlq.business.nlq_chain import NLQChain
99
from nlq.business.profile import ProfileManagement
10+
from nlq.business.vector_store import VectorStore
1011
from utils.apis import get_sql_result_tool
1112
from utils.database import get_db_url_dialect
1213
from utils.llm import text_to_sql, get_query_intent, create_vector_embedding_with_sagemaker, \
@@ -182,7 +183,7 @@ def ask(question: Question) -> Answer:
182183
sql_gen_process="",
183184
data_analyse="")
184185

185-
agent_search_response = AgentSearchResult(agent_summary="", agent_sql_search_result=[])
186+
agent_search_response = AgentSearchResult(agent_summary="", agent_sql_search_result=[], sub_search_task=[])
186187

187188
knowledge_search_result = KnowledgeSearchResult(knowledge_response="")
188189

@@ -221,7 +222,8 @@ def ask(question: Question) -> Answer:
221222

222223
if reject_intent_flag:
223224
answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result,
224-
sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[])
225+
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
226+
suggested_question=[])
225227
return answer
226228
elif search_intent_flag:
227229
normal_search_result = normal_text_search(search_box, model_type,
@@ -232,8 +234,10 @@ def ask(question: Question) -> Answer:
232234
response = knowledge_search(search_box=search_box, model_id=model_type)
233235

234236
knowledge_search_result.knowledge_response = response
235-
answer = Answer(query=search_box, query_intent="knowledge_search", knowledge_search_result=knowledge_search_result,
236-
sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[])
237+
answer = Answer(query=search_box, query_intent="knowledge_search",
238+
knowledge_search_result=knowledge_search_result,
239+
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
240+
suggested_question=[])
237241
return answer
238242

239243
else:
@@ -269,38 +273,60 @@ def ask(question: Question) -> Answer:
269273
search_intent_result["data"].to_json(
270274
orient='records', force_ascii=False), "query")
271275
sql_search_result.data_analyse = search_intent_analyse_result
272-
sql_search_result.sql_data = [list(search_intent_result["data"].columns)] +search_intent_result["data"].values.tolist()
276+
sql_search_result.sql_data = [list(search_intent_result["data"].columns)] + search_intent_result[
277+
"data"].values.tolist()
273278

274279
answer = Answer(query=search_box, query_intent="normal_search", knowledge_search_result=knowledge_search_result,
275-
sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[])
280+
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
281+
suggested_question=[])
276282
return answer
277283
else:
284+
sub_search_task = []
278285
for i in range(len(agent_search_result)):
279286
each_task_res = get_sql_result_tool(database_profile, agent_search_result[i]["sql"])
280287
if each_task_res["status_code"] == 200 and len(each_task_res["data"]) > 0:
281288
agent_search_result[i]["data_result"] = each_task_res["data"].to_json(
282289
orient='records')
283290
filter_deep_dive_sql_result.append(agent_search_result[i])
284291
each_task_sql_res = [list(each_task_res["data"].columns)] + each_task_res["data"].values.tolist()
285-
each_task_sql_search_result = SQLSearchResult(query=agent_search_result[i]["query"], sql_data=each_task_sql_res,
292+
each_task_sql_search_result = SQLSearchResult(query=agent_search_result[i]["query"],
293+
sql_data=each_task_sql_res,
286294
sql=each_task_res["sql"], data_show_type="table",
287295
sql_gen_process="",
288296
data_analyse="")
289297
agent_sql_search_result.append(each_task_sql_search_result)
290-
298+
sub_search_task.append(agent_search_result[i]["query"])
291299
agent_data_analyse_result = data_analyse_tool(model_type, search_box,
292300
json.dumps(filter_deep_dive_sql_result, ensure_ascii=False),
293301
"agent")
294302
logger.info("agent_data_analyse_result")
295303
logger.info(agent_data_analyse_result)
296304
agent_search_response.agent_summary = agent_data_analyse_result
297305
agent_search_response.agent_sql_search_result = agent_sql_search_result
306+
agent_search_response.sub_search_task = sub_search_task
298307

299308
answer = Answer(query=search_box, query_intent="agent_search", knowledge_search_result=knowledge_search_result,
300-
sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[])
309+
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
310+
suggested_question=[])
301311
return answer
302312

303313

314+
def user_feedback_upvote(data_profiles: str, query: str, query_intent: str, query_answer_list):
315+
try:
316+
if query_intent == "normal_search":
317+
if len(query_answer_list) > 0:
318+
VectorStore.add_sample(data_profiles, query_answer_list[0].query, query_answer_list[0].sql)
319+
elif query_intent == "agent_search":
320+
query_list = []
321+
for each in query_answer_list:
322+
query_list.append(each.query)
323+
VectorStore.add_sample(data_profiles, each.query, each.sql)
324+
VectorStore.add_agent_cot_sample(data_profiles, query, "\n".join(query_list))
325+
return True
326+
except Exception as e:
327+
return False
328+
329+
304330
def get_nlq_chain(question: Question) -> NLQChain:
305331
logger.debug(question)
306332
verify_parameters(question)

application/pages/1_🌍_Generative_BI_Playground.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ def do_visualize_results(nlq_chain, sql_result):
6262
# hacky way to get around the issue of selectbox not updating when the options change
6363
chart_type = visualize_config_columns[0].selectbox('Choose the chart type',
6464
['Table', 'Bar', 'Line', 'Pie'],
65-
on_change=nlq_chain.set_visualization_config_change,
66-
key=random.randint(0, 10000)
65+
on_change=nlq_chain.set_visualization_config_change
6766
)
6867
if chart_type != 'Table':
6968
x_column = visualize_config_columns[1].selectbox(f'Choose x-axis column', available_columns,
@@ -93,7 +92,7 @@ def recurrent_display(messages, i, current_nlq_chain):
9392
message = messages[i]
9493
if message["type"] == "pandas":
9594
if isinstance(message["content"], pd.DataFrame):
96-
do_visualize_results(current_nlq_chain, message["content"])
95+
st.dataframe(message["content"], hide_index=True)
9796
elif isinstance(message["content"], list):
9897
for each_content in message["content"]:
9998
st.write(each_content["query"])
@@ -102,13 +101,10 @@ def recurrent_display(messages, i, current_nlq_chain):
102101
st.markdown(message["content"])
103102
elif message["type"] == "error":
104103
st.error(message["content"])
105-
if i + 1 < len(messages):
106-
if current_role != messages[i + 1]["role"]:
107-
return i
108-
else:
109-
return recurrent_display(messages, i + 1, current_nlq_chain)
110-
else:
111-
return i
104+
elif message["type"] == "sql":
105+
with st.expander("The Generate SQL"):
106+
st.code(message["content"], language="sql")
107+
return i
112108

113109

114110
def main():
@@ -221,8 +217,8 @@ def main():
221217
for i in range(len(st.session_state.messages[selected_profile])):
222218
print('!!!!!')
223219
print(i, new_index)
224-
if i - 1 < new_index:
225-
continue
220+
# if i - 1 < new_index:
221+
# continue
226222
with st.chat_message(st.session_state.messages[selected_profile][i]["role"]):
227223
new_index = recurrent_display(st.session_state.messages[selected_profile], i, current_nlq_chain)
228224

0 commit comments

Comments
 (0)