Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 36 additions & 38 deletions application/pages/1_🌍_Generative_BI_Playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ def upvote_agent_clicked(question, comment, env_vars):
def clean_st_history(selected_profile):
st.session_state.messages[selected_profile] = []

def get_user_history(selected_profile):

def get_user_history(selected_profile: str):
"""
get user history for selected profile
:param selected_profile:
:return: history for selected profile list type
"""
history_list = st.session_state.messages[selected_profile]
history_query = []
for messages in history_list:
Expand All @@ -62,6 +68,7 @@ def get_user_history(selected_profile):
history_query.append(messages["content"])
return history_query


def do_visualize_results(nlq_chain, sql_result):
sql_query_result = sql_result
if sql_query_result is not None:
Expand Down Expand Up @@ -98,7 +105,7 @@ def do_visualize_results(nlq_chain, sql_result):
st.markdown('No visualization generated.')


def recurrent_display(messages, i, current_nlq_chain):
def recurrent_display(messages, i):
# hacking way of displaying messages, since the chat_message does not support multiple messages outside of "with" statement
current_role = messages[i]["role"]
message = messages[i]
Expand All @@ -115,12 +122,13 @@ def recurrent_display(messages, i, current_nlq_chain):
st.error(message["content"])
elif message["type"] == "sql":
with st.expander("The Generate SQL"):
st.code(message["content"], language="sql")
st.code(message["content"], language="sql")
return i


def normal_text_search_streamlit(search_box, model_type, database_profile, entity_slot, env_vars, selected_profile, use_rag,
model_provider=None):
def normal_text_search_streamlit(search_box, model_type, database_profile, entity_slot, env_vars, selected_profile,
use_rag,
model_provider=None):
entity_slot_retrieve = []
retrieve_result = []
response = ""
Expand Down Expand Up @@ -154,7 +162,7 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
with st.status("Performing QA retrieval...") as status_text:
if use_rag:
retrieve_result = get_retrieve_opensearch(env_vars, search_box, "query",
selected_profile, 3, 0.5)
selected_profile, 3, 0.5)
examples = []
for example in retrieve_result:
examples.append({'Score': example['_score'],
Expand All @@ -167,14 +175,14 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit

with st.status("Generating SQL... ") as status_text:
response = text_to_sql(database_profile['tables_info'],
database_profile['hints'],
database_profile['prompt_map'],
search_box,
model_id=model_type,
sql_examples=retrieve_result,
ner_example=entity_slot_retrieve,
dialect=database_profile['db_type'],
model_provider=model_provider)
database_profile['hints'],
database_profile['prompt_map'],
search_box,
model_id=model_type,
sql_examples=retrieve_result,
ner_example=entity_slot_retrieve,
dialect=database_profile['db_type'],
model_provider=model_provider)

sql = get_generated_sql(response)

Expand All @@ -193,9 +201,8 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
pass

status_text.update(
label=f"Generating SQL Done",
state="complete", expanded=True)

label=f"Generating SQL Done",
state="complete", expanded=True)

search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve,
retrieve_result=retrieve_result, response=response, sql="")
Expand All @@ -207,10 +214,10 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit
logger.error(e)
return search_result


def main():
load_dotenv()

# load config.json as dictionary
with open(os.path.join(os.getcwd(), 'config_files', '1_config.json')) as f:
env_vars = json.load(f)
opensearch_config = env_vars['data_sources']['shopping_guide']['opensearch']
Expand Down Expand Up @@ -243,12 +250,6 @@ def main():
if 'selected_sample' not in st.session_state:
st.session_state['selected_sample'] = ''

if 'dataframe' not in st.session_state:
st.session_state['dataframe'] = pd.DataFrame({
'column1': ['A', 'B', 'C'],
'column2': [1, 2, 3]
})

if 'current_profile' not in st.session_state:
st.session_state['current_profile'] = ''

Expand Down Expand Up @@ -336,7 +337,7 @@ def main():
# if i - 1 < new_index:
# continue
with st.chat_message(st.session_state.messages[selected_profile][i]["role"]):
new_index = recurrent_display(st.session_state.messages[selected_profile], i, current_nlq_chain)
new_index = recurrent_display(st.session_state.messages[selected_profile], i)

text_placeholder = "Type your query here..."

Expand All @@ -360,9 +361,6 @@ def main():
{"role": "user", "content": search_box, "type": "text"})
st.markdown(current_nlq_chain.get_question())
with st.chat_message("assistant"):
# retrieve_result = []
# entity_slot_retrieve = []
# deep_dive_sql_result = []
filter_deep_dive_sql_result = []
entity_slot = []
normal_search_result = None
Expand All @@ -384,8 +382,7 @@ def main():
prompt_map[key] = prompt_map_dict[key]
ProfileManagement.update_table_prompt_map(selected_profile, prompt_map)


# 多轮对话,query改写
# Multiple rounds of dialogue, query rewriting
user_query_history = get_user_history(selected_profile)
if len(user_query_history) > 0:
with st.status("Query Context Understanding") as status_text:
Expand All @@ -395,7 +392,7 @@ def main():
logger.info("The Origin query is {query} query rewrite is {new_query}".format(query=search_box,
new_query=new_search_box))
search_box = new_search_box
st.write(search_box)
st.write(search_box)
status_text.update(label=f"Query Context Rewrite Completed", state="complete", expanded=False)
intent_response = {
"intent": "normal_search",
Expand Down Expand Up @@ -437,10 +434,10 @@ def main():
elif search_intent_flag:
# 执行普通的查询,并可视化结果
normal_search_result = normal_text_search_streamlit(search_box, model_type,
database_profile,
entity_slot, env_vars,
selected_profile,
explain_gen_process_flag, use_rag_flag)
database_profile,
entity_slot, env_vars,
selected_profile,
explain_gen_process_flag, use_rag_flag)
elif knowledge_search_flag:
with st.spinner('Performing knowledge search...'):
response = knowledge_search(search_box=search_box, model_id=model_type,
Expand All @@ -459,8 +456,8 @@ def main():
agent_examples = []
for example in agent_cot_retrieve:
agent_examples.append({'Score': example['_score'],
'Question': example['_source']['query'],
'Answer': example['_source']['comment'].strip()})
'Question': example['_source']['query'],
'Answer': example['_source']['comment'].strip()})
st.write(agent_examples)
with st.expander(f'Agent Task : {len(agent_cot_task_result)}'):
st.write(agent_cot_task_result)
Expand Down Expand Up @@ -505,7 +502,8 @@ def main():
with st.expander("The SQL Error Info"):
st.markdown(search_intent_result["error_info"])
else:
if search_intent_result["data"] is not None and len(search_intent_result["data"]) > 0 and data_with_analyse:
if search_intent_result["data"] is not None and len(
search_intent_result["data"]) > 0 and data_with_analyse:
with st.spinner('Generating data summarize...'):
search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box,
search_intent_result["data"].to_json(
Expand Down
41 changes: 40 additions & 1 deletion application/utils/env_var.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import json
import os
import boto3
from botocore.exceptions import ClientError
from dotenv import load_dotenv

load_dotenv()
Expand All @@ -18,4 +21,40 @@
AOS_HOST = os.getenv('AOS_HOST')
AOS_PORT = os.getenv('AOS_PORT')
AOS_USER = os.getenv('AOS_USER')
AOS_PASSWORD = os.getenv('AOS_PASSWORD')
AOS_PASSWORD = os.getenv('AOS_PASSWORD')

AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION')

OPENSEARCH_TYPE = os.getenv('OPENSEARCH_TYPE')


def get_opensearch_parameter():
try:
session = boto3.session.Session()
sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)
master_user = sm_client.get_secret_value(SecretId='opensearch-host-url')['SecretString']
data = json.loads(master_user)
es_host_name = data.get('host')
# cluster endpoint, for example: my-test-domain.us-east-1.es.amazonaws.com/
host = es_host_name + '/' if es_host_name[-1] != '/' else es_host_name
host = host[8:-1]

sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)
master_user = sm_client.get_secret_value(SecretId='opensearch-master-user')['SecretString']
data = json.loads(master_user)
username = data.get('username')
password = data.get('password')
port = 443
return host, port, username, password
except ClientError as e:
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e


if OPENSEARCH_TYPE == "service":
opensearch_host, opensearch_port, opensearch_username, opensearch_password = get_opensearch_parameter()
AOS_HOST = opensearch_host
AOS_PORT = opensearch_port
AOS_USER = opensearch_username
AOS_PASSWORD = opensearch_password