diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index 0ae9321..592bd94 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -53,7 +53,13 @@ def upvote_agent_clicked(question, comment, env_vars): def clean_st_history(selected_profile): st.session_state.messages[selected_profile] = [] -def get_user_history(selected_profile): + +def get_user_history(selected_profile: str): + """ + get user history for selected profile + :param selected_profile: + :return: history for selected profile list type + """ history_list = st.session_state.messages[selected_profile] history_query = [] for messages in history_list: @@ -62,6 +68,7 @@ def get_user_history(selected_profile): history_query.append(messages["content"]) return history_query + def do_visualize_results(nlq_chain, sql_result): sql_query_result = sql_result if sql_query_result is not None: @@ -98,7 +105,7 @@ def do_visualize_results(nlq_chain, sql_result): st.markdown('No visualization generated.') -def recurrent_display(messages, i, current_nlq_chain): +def recurrent_display(messages, i): # hacking way of displaying messages, since the chat_message does not support multiple messages outside of "with" statement current_role = messages[i]["role"] message = messages[i] @@ -115,12 +122,13 @@ def recurrent_display(messages, i, current_nlq_chain): st.error(message["content"]) elif message["type"] == "sql": with st.expander("The Generate SQL"): - st.code(message["content"], language="sql") + st.code(message["content"], language="sql") return i -def normal_text_search_streamlit(search_box, model_type, database_profile, entity_slot, env_vars, selected_profile, use_rag, - model_provider=None): +def normal_text_search_streamlit(search_box, model_type, database_profile, entity_slot, env_vars, selected_profile, + use_rag, + 
model_provider=None): entity_slot_retrieve = [] retrieve_result = [] response = "" @@ -154,7 +162,7 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit with st.status("Performing QA retrieval...") as status_text: if use_rag: retrieve_result = get_retrieve_opensearch(env_vars, search_box, "query", - selected_profile, 3, 0.5) + selected_profile, 3, 0.5) examples = [] for example in retrieve_result: examples.append({'Score': example['_score'], @@ -167,14 +175,14 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit with st.status("Generating SQL... ") as status_text: response = text_to_sql(database_profile['tables_info'], - database_profile['hints'], - database_profile['prompt_map'], - search_box, - model_id=model_type, - sql_examples=retrieve_result, - ner_example=entity_slot_retrieve, - dialect=database_profile['db_type'], - model_provider=model_provider) + database_profile['hints'], + database_profile['prompt_map'], + search_box, + model_id=model_type, + sql_examples=retrieve_result, + ner_example=entity_slot_retrieve, + dialect=database_profile['db_type'], + model_provider=model_provider) sql = get_generated_sql(response) @@ -193,9 +201,8 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit pass status_text.update( - label=f"Generating SQL Done", - state="complete", expanded=True) - + label=f"Generating SQL Done", + state="complete", expanded=True) search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve, retrieve_result=retrieve_result, response=response, sql="") @@ -207,10 +214,10 @@ def normal_text_search_streamlit(search_box, model_type, database_profile, entit logger.error(e) return search_result + def main(): load_dotenv() - # load config.json as dictionary with open(os.path.join(os.getcwd(), 'config_files', '1_config.json')) as f: env_vars = json.load(f) opensearch_config = 
env_vars['data_sources']['shopping_guide']['opensearch'] @@ -243,12 +250,6 @@ def main(): if 'selected_sample' not in st.session_state: st.session_state['selected_sample'] = '' - if 'dataframe' not in st.session_state: - st.session_state['dataframe'] = pd.DataFrame({ - 'column1': ['A', 'B', 'C'], - 'column2': [1, 2, 3] - }) - if 'current_profile' not in st.session_state: st.session_state['current_profile'] = '' @@ -336,7 +337,7 @@ def main(): # if i - 1 < new_index: # continue with st.chat_message(st.session_state.messages[selected_profile][i]["role"]): - new_index = recurrent_display(st.session_state.messages[selected_profile], i, current_nlq_chain) + new_index = recurrent_display(st.session_state.messages[selected_profile], i) text_placeholder = "Type your query here..." @@ -360,9 +361,6 @@ def main(): {"role": "user", "content": search_box, "type": "text"}) st.markdown(current_nlq_chain.get_question()) with st.chat_message("assistant"): - # retrieve_result = [] - # entity_slot_retrieve = [] - # deep_dive_sql_result = [] filter_deep_dive_sql_result = [] entity_slot = [] normal_search_result = None @@ -384,8 +382,7 @@ def main(): prompt_map[key] = prompt_map_dict[key] ProfileManagement.update_table_prompt_map(selected_profile, prompt_map) - - # 多轮对话,query改写 + # Multiple rounds of dialogue, query rewriting user_query_history = get_user_history(selected_profile) if len(user_query_history) > 0: with st.status("Query Context Understanding") as status_text: @@ -395,7 +392,7 @@ def main(): logger.info("The Origin query is {query} query rewrite is {new_query}".format(query=search_box, new_query=new_search_box)) search_box = new_search_box - st.write(search_box) + st.write(search_box) status_text.update(label=f"Query Context Rewrite Completed", state="complete", expanded=False) intent_response = { "intent": "normal_search", @@ -437,10 +434,10 @@ def main(): elif search_intent_flag: # 执行普通的查询,并可视化结果 normal_search_result = normal_text_search_streamlit(search_box, 
model_type, - database_profile, - entity_slot, env_vars, - selected_profile, - explain_gen_process_flag, use_rag_flag) + database_profile, + entity_slot, env_vars, + selected_profile, + explain_gen_process_flag, use_rag_flag) elif knowledge_search_flag: with st.spinner('Performing knowledge search...'): response = knowledge_search(search_box=search_box, model_id=model_type, @@ -459,8 +456,8 @@ def main(): agent_examples = [] for example in agent_cot_retrieve: agent_examples.append({'Score': example['_score'], - 'Question': example['_source']['query'], - 'Answer': example['_source']['comment'].strip()}) + 'Question': example['_source']['query'], + 'Answer': example['_source']['comment'].strip()}) st.write(agent_examples) with st.expander(f'Agent Task : {len(agent_cot_task_result)}'): st.write(agent_cot_task_result) @@ -505,7 +502,8 @@ def main(): with st.expander("The SQL Error Info"): st.markdown(search_intent_result["error_info"]) else: - if search_intent_result["data"] is not None and len(search_intent_result["data"]) > 0 and data_with_analyse: + if search_intent_result["data"] is not None and len( + search_intent_result["data"]) > 0 and data_with_analyse: with st.spinner('Generating data summarize...'): search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, search_intent_result["data"].to_json( diff --git a/application/utils/env_var.py b/application/utils/env_var.py index 1f669d5..5785587 100644 --- a/application/utils/env_var.py +++ b/application/utils/env_var.py @@ -1,4 +1,7 @@ +import json import os +import boto3 +from botocore.exceptions import ClientError from dotenv import load_dotenv load_dotenv() @@ -18,4 +21,40 @@ AOS_HOST = os.getenv('AOS_HOST') AOS_PORT = os.getenv('AOS_PORT') AOS_USER = os.getenv('AOS_USER') -AOS_PASSWORD = os.getenv('AOS_PASSWORD') \ No newline at end of file +AOS_PASSWORD = os.getenv('AOS_PASSWORD') + +AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION') + +OPENSEARCH_TYPE = 
os.getenv('OPENSEARCH_TYPE')


def _bare_endpoint(endpoint):
    """Return *endpoint* without a leading ``scheme://`` and without trailing slashes."""
    _scheme, separator, remainder = endpoint.partition('://')
    # No '://' separator means the value is already a bare host name; keep it whole
    # instead of slicing off a fixed number of characters.
    bare = remainder if separator else endpoint
    return bare.rstrip('/')


def get_opensearch_parameter():
    """Fetch the OpenSearch connection settings from AWS Secrets Manager.

    Reads the ``opensearch-host-url`` secret (a JSON document with a ``host``
    key) and the ``opensearch-master-user`` secret (a JSON document with
    ``username`` and ``password`` keys), using the region held in
    ``AWS_DEFAULT_REGION``.

    :return: tuple ``(host, port, username, password)``. ``host`` is the cluster
        endpoint without URL scheme or trailing slash, for example
        ``my-test-domain.us-east-1.es.amazonaws.com``; ``port`` is always the
        integer ``443``.
    :raises botocore.exceptions.ClientError: propagated unchanged when a
        Secrets Manager call fails. For the list of error codes see
        https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
    :raises ValueError: when the host secret carries no usable ``host`` value.
    """
    session = boto3.session.Session()
    # A single client is enough for both secret lookups.
    sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)

    host_secret = json.loads(
        sm_client.get_secret_value(SecretId='opensearch-host-url')['SecretString'])
    es_host_name = host_secret.get('host')
    if not es_host_name:
        # Fail with a readable message rather than an IndexError/TypeError
        # from indexing a missing or empty value.
        raise ValueError("Secret 'opensearch-host-url' does not contain a 'host' value")
    host = _bare_endpoint(es_host_name)

    user_secret = json.loads(
        sm_client.get_secret_value(SecretId='opensearch-master-user')['SecretString'])
    username = user_secret.get('username')
    password = user_secret.get('password')

    port = 443
    return host, port, username, password


# When OPENSEARCH_TYPE is "service", the AOS_* settings read from the environment
# are replaced by the values stored in Secrets Manager. This runs at import time.
if OPENSEARCH_TYPE == "service":
    opensearch_host, opensearch_port, opensearch_username, opensearch_password = get_opensearch_parameter()
    AOS_HOST = opensearch_host
    AOS_PORT = opensearch_port
    AOS_USER = opensearch_username
    AOS_PASSWORD = opensearch_password