Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
63786c8
change convert_timestamps_to_str
Aug 21, 2024
018d942
Merge pull request #267 from aws-samples/v1.6.0_spy
supinyu Aug 21, 2024
04346ee
add error info show
Aug 22, 2024
0b595e8
add error info show
Aug 22, 2024
cdcbe96
Merge pull request #270 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
28a6232
add error info show
Aug 22, 2024
95cba4d
add error info show
Aug 22, 2024
d50c93c
add error info show
Aug 22, 2024
56d13c8
add error info show
Aug 22, 2024
c1995e6
Merge pull request #271 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
82644df
fix dup key
Aug 22, 2024
7df023b
Merge pull request #272 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
e771d39
add profile select
Aug 22, 2024
135804b
add profile select
Aug 22, 2024
b7db217
add profile select
Aug 22, 2024
99cce8a
add update_profile
Aug 22, 2024
70b6ed9
Merge pull request #273 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
2fcf8b9
add get_generated_sql_explain
Aug 22, 2024
1fb959c
Merge pull request #274 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
f4dbd94
add update_profile
Aug 22, 2024
ec2affa
Merge pull request #275 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
8ead976
add update_profile
Aug 22, 2024
9b4153d
add update_profile
Aug 22, 2024
5760315
Merge pull request #276 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
36becba
add st.session_state.ner_refresh_view
Aug 22, 2024
576aeca
add st.session_state.ner_refresh_view
Aug 22, 2024
f921b49
Merge pull request #277 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
58108d3
fix update_profile problem
Aug 22, 2024
2a9dabf
Merge pull request #278 from aws-samples/v1.6.0_spy
supinyu Aug 22, 2024
233bea4
Merge branch 'refs/heads/v1.6.0' into v1.7.0_dev
Aug 22, 2024
6ee71d8
fix update_profile problem
Aug 22, 2024
1bd27f6
Merge branch 'refs/heads/v1.7.0' into v1.7.0_dev
Aug 22, 2024
6da93d2
fix DataSourceFactory
Aug 22, 2024
6f53313
add db_type hive
Aug 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 122 additions & 78 deletions application/nlq/core/state_machine.py

Large diffs are not rendered by default.

29 changes: 20 additions & 9 deletions application/nlq/data_access/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,25 @@ class RelationDatabase():

@classmethod
def get_db_url(cls, db_type, user, password, host, port, db_name):
db_url = db.engine.URL.create(
drivername=cls.db_mapping[db_type],
username=user,
password=password,
host=host,
port=port,
database=db_name
)
if db_type == "hive":
db_url = db.engine.URL.create(
drivername=cls.db_mapping[db_type],
username=user,
password=password,
host=host,
port=port,
database=db_name,
query={'auth': 'LDAP'}
)
else:
db_url = db.engine.URL.create(
drivername=cls.db_mapping[db_type],
username=user,
password=password,
host=host,
port=port,
database=db_name
)
return db_url

@classmethod
Expand All @@ -53,7 +64,7 @@ def get_all_schema_names_by_connection(cls, connection: ConnectConfigEntity):
if db_type == 'postgresql':
schemas = [schema for schema in inspector.get_schema_names() if
schema not in ('pg_catalog', 'information_schema', 'public')]
elif db_type in ('redshift', 'mysql', 'starrocks', 'clickhouse'):
elif db_type in ('redshift', 'mysql', 'starrocks', 'clickhouse', 'hive'):
schemas = inspector.get_schema_names()
else:
raise ValueError("Unsupported database type")
Expand Down
77 changes: 65 additions & 12 deletions application/pages/1_🌍_Generative_BI_Playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import plotly.express as px
from dotenv import load_dotenv

import logging
from api.service import user_feedback_downvote
from nlq.business.connection import ConnectionManagement
from nlq.business.profile import ProfileManagement
Expand All @@ -13,9 +13,9 @@
from nlq.core.state_machine import QueryStateMachine
from utils.navigation import make_sidebar
from utils.env_var import opensearch_info
from utils.logging import getLogger

logger = getLogger()
logger = logging.getLogger(__name__)


def sample_question_clicked(sample):
"""Update the selected_sample variable with the text of the clicked button"""
Expand Down Expand Up @@ -148,17 +148,28 @@ def main():
# Title and Description
st.subheader('Generative BI Playground')

st.write('Current Username: ' + st.session_state['auth_username'])

demo_profile_suffix = '(demo)'
# Initialize or set up state variables

if "update_profile" not in st.session_state:
st.session_state.update_profile = False

if "profiles_list" not in st.session_state:
st.session_state["profiles_list"] = []

if 'profiles' not in st.session_state:
# get all user defined profiles with info (db_url, conn_name, tables_info, hints, search_samples)
all_profiles = ProfileManagement.get_all_profiles_with_info()
# all_profiles.update(demo_profile)
st.session_state['profiles'] = all_profiles
st.session_state["profiles_list"] = list(all_profiles.keys())
else:
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state['profiles'] = all_profiles
if st.session_state.update_profile:
logger.info("session_state update_profile get_all_profiles_with_info")
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state["profiles_list"] = list(all_profiles.keys())
st.session_state['profiles'] = all_profiles
st.session_state.update_profile = False

if "vision_change" not in st.session_state:
st.session_state["vision_change"] = False
Expand Down Expand Up @@ -321,7 +332,7 @@ def main():
search_box=search_box,
query_rewrite="",
session_id="",
user_id=st.session_state['auth_username'],
user_id="",
selected_profile=selected_profile,
database_profile=database_profile,
model_type=model_type,
Expand Down Expand Up @@ -392,21 +403,24 @@ def main():
state_machine.handle_sql_generation()
sql = state_machine.get_answer().sql_search_result.sql
st.code(sql, language="sql")
st.session_state.messages[selected_profile].append(
if not visualize_results_flag:
st.session_state.messages[selected_profile].append(
{"role": "assistant", "content": sql, "type": "sql"})
feedback = st.columns(2)
feedback[0].button('👍 Upvote (save as embedding for retrieval)', type='secondary',
key="upvote",
use_container_width=True,
on_click=upvote_clicked,
args=[search_box,
sql])
feedback[1].button('👎 Downvote', type='secondary', use_container_width=True,
key="downvote",
on_click=downvote_clicked,
args=[search_box, sql])
status_text.update(
label=f"Generating SQL Done",
state="complete", expanded=True)
if state_machine.context.gen_suggested_question_flag:
if state_machine.context.explain_gen_process_flag:
with st.status("Generating explanations...") as status_text:
st.markdown(state_machine.get_answer().sql_search_result.sql_gen_process)
status_text.update(
Expand All @@ -421,7 +435,39 @@ def main():
status_text.update(label=f"Intent Recognition Completed: This is a **{intent}** question",
state="complete", expanded=False)
elif state_machine.get_state() == QueryState.EXECUTE_QUERY:
state_machine.handle_execute_query()
with st.status("Execute SQL...") as status_text:
state_machine.handle_execute_query()
status_text.update(label=f"Execute SQL Done",
state="complete", expanded=False)
sql = state_machine.get_answer().sql_search_result.sql
if state_machine.use_auto_correction_flag:
with st.expander("The SQL Error Info"):
st.markdown(state_machine.first_sql_execute_info["error_info"])
with st.status("Generating SQL Again ... ") as status_text:
st.code(sql, language="sql")
st.session_state.messages[selected_profile].append(
{"role": "assistant", "content": sql, "type": "sql"})
feedback = st.columns(2)
feedback[0].button('👍 Upvote (save as embedding for retrieval)', type='secondary',
key="upvote_again",
use_container_width=True,
on_click=upvote_clicked,
args=[search_box,
sql])
feedback[1].button('👎 Downvote', type='secondary', use_container_width=True,
key="downcote_again",
on_click=downvote_clicked,
args=[search_box, sql])
status_text.update(
label=f"Generating SQL Done",
state="complete", expanded=True)
else:
st.session_state.messages[selected_profile].append(
{"role": "assistant", "content": sql, "type": "sql"})
if state_machine.get_answer().sql_search_result.sql_data is not None:
st.session_state.messages[selected_profile].append(
{"role": "assistant", "content": state_machine.get_answer().sql_search_result.sql_data, "type": "pandas"})

elif state_machine.get_state() == QueryState.ANALYZE_DATA:
with st.spinner('Generating data summarize...'):
state_machine.handle_analyze_data()
Expand Down Expand Up @@ -482,6 +528,12 @@ def main():
st.session_state.current_sql_result = \
state_machine.intent_search_result["sql_execute_result"]["data"]
do_visualize_results()
elif state_machine.get_state() == QueryState.ERROR:
with st.status("The Error Info Please Check") as status_text:
st.write(state_machine.error_log)
status_text.update(label=f"The Error Info Please Check",
state="error", expanded=False)

if processing_context.gen_suggested_question_flag:
if state_machine.search_intent_flag or state_machine.agent_intent_flag:
st.markdown('You might want to further ask:')
Expand All @@ -502,7 +554,8 @@ def main():
on_click=sample_question_clicked,
args=[gen_sq_list[2]])
else:
do_visualize_results()
if visualize_results_flag:
do_visualize_results()


if __name__ == '__main__':
Expand Down
22 changes: 22 additions & 0 deletions application/pages/3_🪙_Data_Profile_Management.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_table_name_by_config(_conn_config, schema_names, default_values):

def show_delete_profile(profile_name):
if st.button('Delete Profile'):
st.session_state.update_profile = True
ProfileManagement.delete_profile(profile_name)
st.success(f"{profile_name} deleted successfully!")
st.session_state.profile_page_mode = 'default'
Expand All @@ -64,12 +65,30 @@ def main():
st.set_page_config(page_title="Data Profile Management", )
make_sidebar()

if "update_profile" not in st.session_state:
st.session_state.update_profile = False

if 'profile_page_mode' not in st.session_state:
st.session_state['profile_page_mode'] = 'default'

if 'current_profile' not in st.session_state:
st.session_state['current_profile'] = ''

if "profiles_list" not in st.session_state:
st.session_state["profiles_list"] = []

if 'profiles' not in st.session_state:
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state['profiles'] = all_profiles
st.session_state["profiles_list"] = list(all_profiles.keys())

if st.session_state.update_profile:
logger.info("session_state update_profile get_all_profiles_with_info")
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state["profiles_list"] = list(all_profiles.keys())
st.session_state['profiles'] = all_profiles
st.session_state.update_profile = False

with st.sidebar:
st.title("Data Profile Management")
st.selectbox("My Data Profiles", get_all_profiles(),
Expand All @@ -94,6 +113,7 @@ def main():
comments = st.text_input("Comments")

if st.button('Create Profile', type='primary'):
st.session_state.update_profile = True
if not selected_tables:
st.error('Please select at least one table.')
return
Expand Down Expand Up @@ -156,6 +176,7 @@ def main():
column_value: $login_user.username""", disabled=not st_enable_rls, height=240)

if st.button('Update Profile', type='primary'):
st.session_state.update_profile = True
if not selected_tables:
st.error('Please select at least one table.')
return
Expand All @@ -173,6 +194,7 @@ def main():
st.cache_data.clear()

if st.button('Fetch table definition'):
st.session_state.update_profile = True
if not selected_tables:
st.error('Please select at least one table.')
with st.spinner('fetching...'):
Expand Down
23 changes: 21 additions & 2 deletions application/pages/4_🪙_Schema_Description_Management.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,32 @@ def main():
if 'current_profile' not in st.session_state:
st.session_state['current_profile'] = ''

if "update_profile" not in st.session_state:
st.session_state.update_profile = False

if "profiles_list" not in st.session_state:
st.session_state["profiles_list"] = []

if 'profiles' not in st.session_state:
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state['profiles'] = all_profiles
st.session_state["profiles_list"] = list(all_profiles.keys())

if st.session_state.update_profile:
logger.info("session_state update_profile get_all_profiles_with_info")
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state["profiles_list"] = list(all_profiles.keys())
st.session_state['profiles'] = all_profiles
st.session_state.update_profile = False

with st.sidebar:
st.title("Schema Management")
all_profiles_list = ProfileManagement.get_all_profiles()
all_profiles_list = st.session_state["profiles_list"]
if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list:
profile_index = all_profiles_list.index(st.session_state.current_profile)
current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index)
else:
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
current_profile = st.selectbox("My Data Profiles", all_profiles_list,
index=None,
placeholder="Please select data profile...", key='current_profile_name')

Expand Down Expand Up @@ -56,6 +74,7 @@ def main():
);
''')
if st.button('Save', type='primary'):
st.session_state.update_profile = True
origin_tables_info = profile_detail.tables_info
origin_table_info = origin_tables_info[selected_table]
origin_table_info['tbl_a'] = tbl_annotation
Expand Down
23 changes: 21 additions & 2 deletions application/pages/5_🪙_Prompt_Management.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,32 @@ def main():
if 'current_profile' not in st.session_state:
st.session_state['current_profile'] = ''

if "update_profile" not in st.session_state:
st.session_state.update_profile = False

if "profiles_list" not in st.session_state:
st.session_state["profiles_list"] = []

if 'profiles' not in st.session_state:
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state['profiles'] = all_profiles
st.session_state["profiles_list"] = list(all_profiles.keys())

if st.session_state.update_profile:
logger.info("session_state update_profile get_all_profiles_with_info")
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state["profiles_list"] = list(all_profiles.keys())
st.session_state['profiles'] = all_profiles
st.session_state.update_profile = False

with st.sidebar:
st.title("Prompt Management")
all_profiles_list = ProfileManagement.get_all_profiles()
all_profiles_list = st.session_state["profiles_list"]
if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list:
profile_index = all_profiles_list.index(st.session_state.current_profile)
current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index)
else:
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
current_profile = st.selectbox("My Data Profiles", all_profiles_list,
index=None,
placeholder="Please select data profile...", key='current_profile_name')

Expand Down Expand Up @@ -54,6 +72,7 @@ def main():

if st.button('Save', type='primary'):
# check prompt syntax, missing placeholder will cause backend execution failure
st.session_state.update_profile = True
if check_prompt_syntax(system_prompt_input, user_prompt_input,
prompt_type_selected_table, model_selected_table):
# assign new system/user prompt by selected model
Expand Down
22 changes: 20 additions & 2 deletions application/pages/6_📚_Index_Management.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,32 @@ def main():
if 'current_profile' not in st.session_state:
st.session_state['current_profile'] = ''

if "update_profile" not in st.session_state:
st.session_state.update_profile = False

if "profiles_list" not in st.session_state:
st.session_state["profiles_list"] = []

if 'profiles' not in st.session_state:
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state['profiles'] = all_profiles
st.session_state["profiles_list"] = list(all_profiles.keys())

if st.session_state.update_profile:
logger.info("session_state update_profile get_all_profiles_with_info")
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state["profiles_list"] = list(all_profiles.keys())
st.session_state['profiles'] = all_profiles
st.session_state.update_profile = False

with st.sidebar:
st.title("Index Management")
all_profiles_list = ProfileManagement.get_all_profiles()
all_profiles_list = st.session_state["profiles_list"]
if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list:
profile_index = all_profiles_list.index(st.session_state.current_profile)
current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index)
else:
current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(),
current_profile = st.selectbox("My Data Profiles", all_profiles_list,
index=None,
placeholder="Please select data profile...", key='current_profile_name')

Expand Down
Loading