Skip to content

Commit 9daf1cb

Browse files
author
Pinyu Su
committed
change do_visualize_results
1 parent fbe92d7 commit 9daf1cb

File tree

1 file changed

+46
-33
lines changed

1 file changed

+46
-33
lines changed

application/pages/1_🌍_Generative_BI_Playground.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import json
2-
import os
32
import streamlit as st
43
import pandas as pd
54
import plotly.express as px
65
from dotenv import load_dotenv
76
import logging
8-
import random
97

108
from api.service import user_feedback_downvote
119
from nlq.business.connection import ConnectionManagement
@@ -65,6 +63,8 @@ def downvote_clicked(question, comment):
6563
def clean_st_history(selected_profile):
6664
st.session_state.messages[selected_profile] = []
6765
st.session_state.query_rewrite_history[selected_profile] = []
66+
st.session_state.current_sql_result[selected_profile] = None
67+
6868

6969

7070
def get_user_history(selected_profile: str):
@@ -80,41 +80,50 @@ def get_user_history(selected_profile: str):
8080
history_query.append(messages["role"] + ":" + messages["content"])
8181
return history_query
8282

83+
def set_vision_change():
84+
st.session_state.vision_change = True
85+
8386

84-
def do_visualize_results(nlq_chain, sql_result):
85-
sql_query_result = sql_result
87+
def do_visualize_results(selected_profile):
88+
sql_query_result = st.session_state.current_sql_result[selected_profile]
8689
if sql_query_result is not None:
87-
nlq_chain.set_visualization_config_change(False)
8890
# Auto-detect columns
89-
visualize_config_columns = st.columns(3)
91+
available_columns = sql_query_result.columns.tolist()
92+
93+
# Initialize session state for x_column and y_column if not already present
94+
if 'x_column' not in st.session_state or st.session_state.x_column is None:
95+
st.session_state.x_column = available_columns[0] if available_columns else None
96+
if 'y_column' not in st.session_state or st.session_state.x_column is None:
97+
st.session_state.y_column = available_columns[0] if available_columns else None
9098

91-
available_columns = sql_query_result.columns
99+
# Layout configuration
100+
col1, col2, col3 = st.columns([1, 1, 2])
101+
102+
# Chart type selection
103+
chart_type = col1.selectbox('Choose the chart type', ['Table', 'Bar', 'Line', 'Pie'],
104+
on_change=set_vision_change)
92105

93-
# hacky way to get around the issue of selectbox not updating when the options change
94-
chart_type = visualize_config_columns[0].selectbox('Choose the chart type',
95-
['Table', 'Bar', 'Line', 'Pie'],
96-
on_change=nlq_chain.set_visualization_config_change
97-
)
98106
if chart_type != 'Table':
99-
x_column = visualize_config_columns[1].selectbox(f'Choose x-axis column', available_columns,
100-
on_change=nlq_chain.set_visualization_config_change,
101-
key=random.randint(0, 10000)
102-
)
103-
y_column = visualize_config_columns[2].selectbox('Choose y-axis column',
104-
reversed(available_columns.to_list()),
105-
on_change=nlq_chain.set_visualization_config_change,
106-
key=random.randint(0, 10000)
107-
)
107+
# X-axis and Y-axis selection
108+
st.session_state.x_column = col2.selectbox('Choose x-axis column', available_columns,
109+
on_change=set_vision_change,
110+
index=available_columns.index(
111+
st.session_state.x_column) if st.session_state.x_column in available_columns else 0)
112+
st.session_state.y_column = col3.selectbox('Choose y-axis column', available_columns,
113+
on_change=set_vision_change,
114+
index=available_columns.index(
115+
st.session_state.y_column) if st.session_state.y_column in available_columns else 0)
116+
117+
# Visualization
108118
if chart_type == 'Table':
109119
st.dataframe(sql_query_result, hide_index=True)
110120
elif chart_type == 'Bar':
111-
st.plotly_chart(px.bar(sql_query_result, x=x_column, y=y_column))
121+
st.plotly_chart(px.bar(sql_query_result, x=st.session_state.x_column, y=st.session_state.y_column))
112122
elif chart_type == 'Line':
113-
st.plotly_chart(px.line(sql_query_result, x=x_column, y=y_column))
123+
st.plotly_chart(px.line(sql_query_result, x=st.session_state.x_column, y=st.session_state.y_column))
114124
elif chart_type == 'Pie':
115-
st.plotly_chart(px.pie(sql_query_result, names=x_column, values=y_column))
116-
else:
117-
st.markdown('No visualization generated.')
125+
st.plotly_chart(px.pie(sql_query_result, names=st.session_state.x_column, values=st.session_state.y_column))
126+
118127

119128

120129
def recurrent_display(messages, i):
@@ -247,6 +256,9 @@ def main():
247256
all_profiles = ProfileManagement.get_all_profiles_with_info()
248257
st.session_state['profiles'] = all_profiles
249258

259+
if "vision_change" not in st.session_state:
260+
st.session_state["vision_change"] = False
261+
250262
if 'selected_sample' not in st.session_state:
251263
st.session_state['selected_sample'] = ''
252264

@@ -313,6 +325,9 @@ def main():
313325
st.session_state.query_rewrite_history[selected_profile] = []
314326
st.session_state.nlq_chain = NLQChain(selected_profile)
315327

328+
if selected_profile not in st.session_state.current_sql_result:
329+
st.session_state.current_sql_result[selected_profile] = None
330+
316331
if st.session_state.current_model_id != "" and st.session_state.current_model_id in model_ids:
317332
model_index = model_ids.index(st.session_state.current_model_id)
318333
model_type = st.selectbox("Choose your model", model_ids, index=model_index)
@@ -384,9 +399,9 @@ def main():
384399
knowledge_search_flag = False
385400

386401
# add select box for which model to use
387-
if search_box != "Type your query here..." or \
388-
current_nlq_chain.is_visualization_config_changed():
402+
if search_box != "Type your query here..." or st.session_state.vision_change:
389403
if search_box is not None and len(search_box) > 0:
404+
st.session_state.current_sql_result[selected_profile] = None
390405
with st.chat_message("user"):
391406
current_nlq_chain.set_question(search_box)
392407
st.session_state.messages[selected_profile].append(
@@ -642,7 +657,7 @@ def main():
642657
st.session_state.messages[selected_profile].append(
643658
{"role": "assistant", "content": current_search_sql_result, "type": "pandas"})
644659

645-
do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile])
660+
do_visualize_results(selected_profile)
646661
else:
647662
st.markdown("No relevant data found")
648663

@@ -666,10 +681,8 @@ def main():
666681
on_click=sample_question_clicked,
667682
args=[gen_sq_list[2]])
668683
else:
669-
670-
if current_nlq_chain.is_visualization_config_changed():
671-
if visualize_results_flag:
672-
do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile])
684+
if visualize_results_flag:
685+
do_visualize_results(selected_profile)
673686

674687

675688
if __name__ == '__main__':

0 commit comments

Comments
 (0)