1
1
import json
2
- import os
3
2
import streamlit as st
4
3
import pandas as pd
5
4
import plotly .express as px
6
5
from dotenv import load_dotenv
7
6
import logging
8
- import random
9
7
10
8
from api .service import user_feedback_downvote
11
9
from nlq .business .connection import ConnectionManagement
@@ -65,6 +63,8 @@ def downvote_clicked(question, comment):
65
63
def clean_st_history (selected_profile ):
66
64
st .session_state .messages [selected_profile ] = []
67
65
st .session_state .query_rewrite_history [selected_profile ] = []
66
+ st .session_state .current_sql_result [selected_profile ] = None
67
+
68
68
69
69
70
70
def get_user_history (selected_profile : str ):
@@ -80,41 +80,50 @@ def get_user_history(selected_profile: str):
80
80
history_query .append (messages ["role" ] + ":" + messages ["content" ])
81
81
return history_query
82
82
83
+ def set_vision_change ():
84
+ st .session_state .vision_change = True
85
+
83
86
84
- def do_visualize_results (nlq_chain , sql_result ):
85
- sql_query_result = sql_result
87
+ def do_visualize_results (selected_profile ):
88
+ sql_query_result = st . session_state . current_sql_result [ selected_profile ]
86
89
if sql_query_result is not None :
87
- nlq_chain .set_visualization_config_change (False )
88
90
# Auto-detect columns
89
- visualize_config_columns = st .columns (3 )
91
+ available_columns = sql_query_result .columns .tolist ()
92
+
93
+ # Initialize session state for x_column and y_column if not already present
94
+ if 'x_column' not in st .session_state or st .session_state .x_column is None :
95
+ st .session_state .x_column = available_columns [0 ] if available_columns else None
96
+ if 'y_column' not in st .session_state or st .session_state .x_column is None :
97
+ st .session_state .y_column = available_columns [0 ] if available_columns else None
90
98
91
- available_columns = sql_query_result .columns
99
+ # Layout configuration
100
+ col1 , col2 , col3 = st .columns ([1 , 1 , 2 ])
101
+
102
+ # Chart type selection
103
+ chart_type = col1 .selectbox ('Choose the chart type' , ['Table' , 'Bar' , 'Line' , 'Pie' ],
104
+ on_change = set_vision_change )
92
105
93
- # hacky way to get around the issue of selectbox not updating when the options change
94
- chart_type = visualize_config_columns [0 ].selectbox ('Choose the chart type' ,
95
- ['Table' , 'Bar' , 'Line' , 'Pie' ],
96
- on_change = nlq_chain .set_visualization_config_change
97
- )
98
106
if chart_type != 'Table' :
99
- x_column = visualize_config_columns [1 ].selectbox (f'Choose x-axis column' , available_columns ,
100
- on_change = nlq_chain .set_visualization_config_change ,
101
- key = random .randint (0 , 10000 )
102
- )
103
- y_column = visualize_config_columns [2 ].selectbox ('Choose y-axis column' ,
104
- reversed (available_columns .to_list ()),
105
- on_change = nlq_chain .set_visualization_config_change ,
106
- key = random .randint (0 , 10000 )
107
- )
107
+ # X-axis and Y-axis selection
108
+ st .session_state .x_column = col2 .selectbox ('Choose x-axis column' , available_columns ,
109
+ on_change = set_vision_change ,
110
+ index = available_columns .index (
111
+ st .session_state .x_column ) if st .session_state .x_column in available_columns else 0 )
112
+ st .session_state .y_column = col3 .selectbox ('Choose y-axis column' , available_columns ,
113
+ on_change = set_vision_change ,
114
+ index = available_columns .index (
115
+ st .session_state .y_column ) if st .session_state .y_column in available_columns else 0 )
116
+
117
+ # Visualization
108
118
if chart_type == 'Table' :
109
119
st .dataframe (sql_query_result , hide_index = True )
110
120
elif chart_type == 'Bar' :
111
- st .plotly_chart (px .bar (sql_query_result , x = x_column , y = y_column ))
121
+ st .plotly_chart (px .bar (sql_query_result , x = st . session_state . x_column , y = st . session_state . y_column ))
112
122
elif chart_type == 'Line' :
113
- st .plotly_chart (px .line (sql_query_result , x = x_column , y = y_column ))
123
+ st .plotly_chart (px .line (sql_query_result , x = st . session_state . x_column , y = st . session_state . y_column ))
114
124
elif chart_type == 'Pie' :
115
- st .plotly_chart (px .pie (sql_query_result , names = x_column , values = y_column ))
116
- else :
117
- st .markdown ('No visualization generated.' )
125
+ st .plotly_chart (px .pie (sql_query_result , names = st .session_state .x_column , values = st .session_state .y_column ))
126
+
118
127
119
128
120
129
def recurrent_display (messages , i ):
@@ -247,6 +256,9 @@ def main():
247
256
all_profiles = ProfileManagement .get_all_profiles_with_info ()
248
257
st .session_state ['profiles' ] = all_profiles
249
258
259
+ if "vision_change" not in st .session_state :
260
+ st .session_state ["vision_change" ] = False
261
+
250
262
if 'selected_sample' not in st .session_state :
251
263
st .session_state ['selected_sample' ] = ''
252
264
@@ -313,6 +325,9 @@ def main():
313
325
st .session_state .query_rewrite_history [selected_profile ] = []
314
326
st .session_state .nlq_chain = NLQChain (selected_profile )
315
327
328
+ if selected_profile not in st .session_state .current_sql_result :
329
+ st .session_state .current_sql_result [selected_profile ] = None
330
+
316
331
if st .session_state .current_model_id != "" and st .session_state .current_model_id in model_ids :
317
332
model_index = model_ids .index (st .session_state .current_model_id )
318
333
model_type = st .selectbox ("Choose your model" , model_ids , index = model_index )
@@ -384,9 +399,9 @@ def main():
384
399
knowledge_search_flag = False
385
400
386
401
# add select box for which model to use
387
- if search_box != "Type your query here..." or \
388
- current_nlq_chain .is_visualization_config_changed ():
402
+ if search_box != "Type your query here..." or st .session_state .vision_change :
389
403
if search_box is not None and len (search_box ) > 0 :
404
+ st .session_state .current_sql_result [selected_profile ] = None
390
405
with st .chat_message ("user" ):
391
406
current_nlq_chain .set_question (search_box )
392
407
st .session_state .messages [selected_profile ].append (
@@ -642,7 +657,7 @@ def main():
642
657
st .session_state .messages [selected_profile ].append (
643
658
{"role" : "assistant" , "content" : current_search_sql_result , "type" : "pandas" })
644
659
645
- do_visualize_results (current_nlq_chain , st . session_state . current_sql_result [ selected_profile ] )
660
+ do_visualize_results (selected_profile )
646
661
else :
647
662
st .markdown ("No relevant data found" )
648
663
@@ -666,10 +681,8 @@ def main():
666
681
on_click = sample_question_clicked ,
667
682
args = [gen_sq_list [2 ]])
668
683
else :
669
-
670
- if current_nlq_chain .is_visualization_config_changed ():
671
- if visualize_results_flag :
672
- do_visualize_results (current_nlq_chain , st .session_state .current_sql_result [selected_profile ])
684
+ if visualize_results_flag :
685
+ do_visualize_results (selected_profile )
673
686
674
687
675
688
if __name__ == '__main__' :
0 commit comments