Skip to content

Commit e972678

Browse files
authored
Merge pull request #102 from aws-samples/spy_dev
add secretsmanager get SecretString
2 parents 24b7582 + 5bc3f23 commit e972678

File tree

2 files changed

+76
-39
lines changed

2 files changed

+76
-39
lines changed

application/pages/1_🌍_Generative_BI_Playground.py

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@ def upvote_agent_clicked(question, comment, env_vars):
5353
def clean_st_history(selected_profile):
5454
st.session_state.messages[selected_profile] = []
5555

56-
def get_user_history(selected_profile):
56+
57+
def get_user_history(selected_profile: str):
58+
"""
59+
get user history for selected profile
60+
:param selected_profile:
61+
:return: history for selected profile list type
62+
"""
5763
history_list = st.session_state.messages[selected_profile]
5864
history_query = []
5965
for messages in history_list:
@@ -62,6 +68,7 @@ def get_user_history(selected_profile):
6268
history_query.append(messages["content"])
6369
return history_query
6470

71+
6572
def do_visualize_results(nlq_chain, sql_result):
6673
sql_query_result = sql_result
6774
if sql_query_result is not None:
@@ -98,7 +105,7 @@ def do_visualize_results(nlq_chain, sql_result):
98105
st.markdown('No visualization generated.')
99106

100107

101-
def recurrent_display(messages, i, current_nlq_chain):
108+
def recurrent_display(messages, i):
102109
# hacking way of displaying messages, since the chat_message does not support multiple messages outside of "with" statement
103110
current_role = messages[i]["role"]
104111
message = messages[i]
@@ -115,12 +122,13 @@ def recurrent_display(messages, i, current_nlq_chain):
115122
st.error(message["content"])
116123
elif message["type"] == "sql":
117124
with st.expander("The Generate SQL"):
118-
st.code(message["content"], language="sql")
125+
st.code(message["content"], language="sql")
119126
return i
120127

121128

122-
def normal_text_search_streamlit(search_box, model_type, database_profile, entity_slot, env_vars, selected_profile, use_rag,
123-
model_provider=None):
129+
def normal_text_search_streamlit(search_box, model_type, database_profile, entity_slot, env_vars, selected_profile,
130+
use_rag,
131+
model_provider=None):
124132
entity_slot_retrieve = []
125133
retrieve_result = []
126134
response = ""
@@ -154,7 +162,7 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
154162
with st.status("Performing QA retrieval...") as status_text:
155163
if use_rag:
156164
retrieve_result = get_retrieve_opensearch(env_vars, search_box, "query",
157-
selected_profile, 3, 0.5)
165+
selected_profile, 3, 0.5)
158166
examples = []
159167
for example in retrieve_result:
160168
examples.append({'Score': example['_score'],
@@ -167,14 +175,14 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
167175

168176
with st.status("Generating SQL... ") as status_text:
169177
response = text_to_sql(database_profile['tables_info'],
170-
database_profile['hints'],
171-
database_profile['prompt_map'],
172-
search_box,
173-
model_id=model_type,
174-
sql_examples=retrieve_result,
175-
ner_example=entity_slot_retrieve,
176-
dialect=database_profile['db_type'],
177-
model_provider=model_provider)
178+
database_profile['hints'],
179+
database_profile['prompt_map'],
180+
search_box,
181+
model_id=model_type,
182+
sql_examples=retrieve_result,
183+
ner_example=entity_slot_retrieve,
184+
dialect=database_profile['db_type'],
185+
model_provider=model_provider)
178186

179187
sql = get_generated_sql(response)
180188

@@ -193,9 +201,8 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
193201
pass
194202

195203
status_text.update(
196-
label=f"Generating SQL Done",
197-
state="complete", expanded=True)
198-
204+
label=f"Generating SQL Done",
205+
state="complete", expanded=True)
199206

200207
search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve,
201208
retrieve_result=retrieve_result, response=response, sql="")
@@ -207,10 +214,10 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
207214
logger.error(e)
208215
return search_result
209216

217+
210218
def main():
211219
load_dotenv()
212220

213-
# load config.json as dictionary
214221
with open(os.path.join(os.getcwd(), 'config_files', '1_config.json')) as f:
215222
env_vars = json.load(f)
216223
opensearch_config = env_vars['data_sources']['shopping_guide']['opensearch']
@@ -243,12 +250,6 @@ def main():
243250
if 'selected_sample' not in st.session_state:
244251
st.session_state['selected_sample'] = ''
245252

246-
if 'dataframe' not in st.session_state:
247-
st.session_state['dataframe'] = pd.DataFrame({
248-
'column1': ['A', 'B', 'C'],
249-
'column2': [1, 2, 3]
250-
})
251-
252253
if 'current_profile' not in st.session_state:
253254
st.session_state['current_profile'] = ''
254255

@@ -336,7 +337,7 @@ def main():
336337
# if i - 1 < new_index:
337338
# continue
338339
with st.chat_message(st.session_state.messages[selected_profile][i]["role"]):
339-
new_index = recurrent_display(st.session_state.messages[selected_profile], i, current_nlq_chain)
340+
new_index = recurrent_display(st.session_state.messages[selected_profile], i)
340341

341342
text_placeholder = "Type your query here..."
342343

@@ -360,9 +361,6 @@ def main():
360361
{"role": "user", "content": search_box, "type": "text"})
361362
st.markdown(current_nlq_chain.get_question())
362363
with st.chat_message("assistant"):
363-
# retrieve_result = []
364-
# entity_slot_retrieve = []
365-
# deep_dive_sql_result = []
366364
filter_deep_dive_sql_result = []
367365
entity_slot = []
368366
normal_search_result = None
@@ -384,8 +382,7 @@ def main():
384382
prompt_map[key] = prompt_map_dict[key]
385383
ProfileManagement.update_table_prompt_map(selected_profile, prompt_map)
386384

387-
388-
# 多轮对话,query改写
385+
# Multiple rounds of dialogue, query rewriting
389386
user_query_history = get_user_history(selected_profile)
390387
if len(user_query_history) > 0:
391388
with st.status("Query Context Understanding") as status_text:
@@ -395,7 +392,7 @@ def main():
395392
logger.info("The Origin query is {query} query rewrite is {new_query}".format(query=search_box,
396393
new_query=new_search_box))
397394
search_box = new_search_box
398-
st.write(search_box)
395+
st.write(search_box)
399396
status_text.update(label=f"Query Context Rewrite Completed", state="complete", expanded=False)
400397
intent_response = {
401398
"intent": "normal_search",
@@ -437,10 +434,10 @@ def main():
437434
elif search_intent_flag:
438435
# 执行普通的查询,并可视化结果
439436
normal_search_result = normal_text_search_streamlit(search_box, model_type,
440-
database_profile,
441-
entity_slot, env_vars,
442-
selected_profile,
443-
explain_gen_process_flag, use_rag_flag)
437+
database_profile,
438+
entity_slot, env_vars,
439+
selected_profile,
440+
explain_gen_process_flag, use_rag_flag)
444441
elif knowledge_search_flag:
445442
with st.spinner('Performing knowledge search...'):
446443
response = knowledge_search(search_box=search_box, model_id=model_type,
@@ -459,8 +456,8 @@ def main():
459456
agent_examples = []
460457
for example in agent_cot_retrieve:
461458
agent_examples.append({'Score': example['_score'],
462-
'Question': example['_source']['query'],
463-
'Answer': example['_source']['comment'].strip()})
459+
'Question': example['_source']['query'],
460+
'Answer': example['_source']['comment'].strip()})
464461
st.write(agent_examples)
465462
with st.expander(f'Agent Task : {len(agent_cot_task_result)}'):
466463
st.write(agent_cot_task_result)
@@ -505,7 +502,8 @@ def main():
505502
with st.expander("The SQL Error Info"):
506503
st.markdown(search_intent_result["error_info"])
507504
else:
508-
if search_intent_result["data"] is not None and len(search_intent_result["data"]) > 0 and data_with_analyse:
505+
if search_intent_result["data"] is not None and len(
506+
search_intent_result["data"]) > 0 and data_with_analyse:
509507
with st.spinner('Generating data summarize...'):
510508
search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box,
511509
search_intent_result["data"].to_json(

application/utils/env_var.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import json
12
import os
3+
import boto3
4+
from botocore.exceptions import ClientError
25
from dotenv import load_dotenv
36

47
load_dotenv()
@@ -18,4 +21,40 @@
1821
AOS_HOST = os.getenv('AOS_HOST')
1922
AOS_PORT = os.getenv('AOS_PORT')
2023
AOS_USER = os.getenv('AOS_USER')
21-
AOS_PASSWORD = os.getenv('AOS_PASSWORD')
24+
AOS_PASSWORD = os.getenv('AOS_PASSWORD')
25+
26+
AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION')
27+
28+
OPENSEARCH_TYPE = os.getenv('OPENSEARCH_TYPE')
29+
30+
31+
def get_opensearch_parameter():
32+
try:
33+
session = boto3.session.Session()
34+
sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)
35+
master_user = sm_client.get_secret_value(SecretId='opensearch-host-url')['SecretString']
36+
data = json.loads(master_user)
37+
es_host_name = data.get('host')
38+
# cluster endpoint, for example: my-test-domain.us-east-1.es.amazonaws.com/
39+
host = es_host_name + '/' if es_host_name[-1] != '/' else es_host_name
40+
host = host[8:-1]
41+
42+
sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)
43+
master_user = sm_client.get_secret_value(SecretId='opensearch-master-user')['SecretString']
44+
data = json.loads(master_user)
45+
username = data.get('username')
46+
password = data.get('password')
47+
port = 443
48+
return host, port, username, password
49+
except ClientError as e:
50+
# For a list of exceptions thrown, see
51+
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
52+
raise e
53+
54+
55+
if OPENSEARCH_TYPE == "service":
56+
opensearch_host, opensearch_port, opensearch_username, opensearch_password = get_opensearch_parameter()
57+
AOS_HOST = opensearch_host
58+
AOS_PORT = opensearch_port
59+
AOS_USER = opensearch_username
60+
AOS_PASSWORD = opensearch_password

0 commit comments

Comments
 (0)