From 9daf1cb1eca8c9abcc10a1927d6f1e42828fb6d9 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Mon, 19 Aug 2024 15:08:09 +0800 Subject: [PATCH] change do_visualize_results --- ...0\237\214\215_Generative_BI_Playground.py" | 79 +++++++++++-------- 1 file changed, 46 insertions(+), 33 deletions(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index 1848a8ba..6e0f6e0d 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -1,11 +1,9 @@ import json -import os import streamlit as st import pandas as pd import plotly.express as px from dotenv import load_dotenv import logging -import random from api.service import user_feedback_downvote from nlq.business.connection import ConnectionManagement @@ -65,6 +63,8 @@ def downvote_clicked(question, comment): def clean_st_history(selected_profile): st.session_state.messages[selected_profile] = [] st.session_state.query_rewrite_history[selected_profile] = [] + st.session_state.current_sql_result[selected_profile] = None + def get_user_history(selected_profile: str): @@ -80,41 +80,50 @@ def get_user_history(selected_profile: str): history_query.append(messages["role"] + ":" + messages["content"]) return history_query +def set_vision_change(): + st.session_state.vision_change = True + -def do_visualize_results(nlq_chain, sql_result): - sql_query_result = sql_result +def do_visualize_results(selected_profile): + sql_query_result = st.session_state.current_sql_result[selected_profile] if sql_query_result is not None: - nlq_chain.set_visualization_config_change(False) # Auto-detect columns - visualize_config_columns = st.columns(3) + available_columns = sql_query_result.columns.tolist() + + # Initialize session state for x_column and y_column if not already present + if 'x_column' not in st.session_state or st.session_state.x_column is None: + st.session_state.x_column = available_columns[0] if available_columns else None + if 'y_column' not in st.session_state or st.session_state.x_column is None: + st.session_state.y_column = available_columns[0] if available_columns else None - available_columns = sql_query_result.columns + # Layout configuration + col1, col2, col3 = st.columns([1, 1, 2]) + + # Chart type selection + chart_type = col1.selectbox('Choose the chart type', ['Table', 'Bar', 'Line', 'Pie'], + on_change=set_vision_change) - # hacky way to get around the issue of selectbox not updating when the options change - chart_type = visualize_config_columns[0].selectbox('Choose the chart type', - ['Table', 'Bar', 'Line', 'Pie'], - on_change=nlq_chain.set_visualization_config_change - ) if chart_type != 'Table': - x_column = visualize_config_columns[1].selectbox(f'Choose x-axis column', available_columns, - on_change=nlq_chain.set_visualization_config_change, - key=random.randint(0, 10000) - ) - y_column = visualize_config_columns[2].selectbox('Choose y-axis column', - reversed(available_columns.to_list()), - on_change=nlq_chain.set_visualization_config_change, - key=random.randint(0, 10000) - ) + # X-axis and Y-axis selection + st.session_state.x_column = col2.selectbox('Choose x-axis column', available_columns, + on_change=set_vision_change, + index=available_columns.index( + st.session_state.x_column) if st.session_state.x_column in available_columns else 0) + st.session_state.y_column = col3.selectbox('Choose y-axis column', available_columns, + on_change=set_vision_change, + index=available_columns.index( + st.session_state.y_column) if st.session_state.y_column in available_columns else 0) + + # Visualization if chart_type == 'Table': st.dataframe(sql_query_result, hide_index=True) elif chart_type == 'Bar': - st.plotly_chart(px.bar(sql_query_result, x=x_column, y=y_column)) + st.plotly_chart(px.bar(sql_query_result, x=st.session_state.x_column, y=st.session_state.y_column)) elif chart_type == 'Line': - st.plotly_chart(px.line(sql_query_result, x=x_column, y=y_column)) + st.plotly_chart(px.line(sql_query_result, x=st.session_state.x_column, y=st.session_state.y_column)) elif chart_type == 'Pie': - st.plotly_chart(px.pie(sql_query_result, names=x_column, values=y_column)) - else: - st.markdown('No visualization generated.') + st.plotly_chart(px.pie(sql_query_result, names=st.session_state.x_column, values=st.session_state.y_column)) + def recurrent_display(messages, i): @@ -247,6 +256,9 @@ def main(): all_profiles = ProfileManagement.get_all_profiles_with_info() st.session_state['profiles'] = all_profiles + if "vision_change" not in st.session_state: + st.session_state["vision_change"] = False + if 'selected_sample' not in st.session_state: st.session_state['selected_sample'] = '' @@ -313,6 +325,9 @@ def main(): st.session_state.query_rewrite_history[selected_profile] = [] st.session_state.nlq_chain = NLQChain(selected_profile) + if selected_profile not in st.session_state.current_sql_result: + st.session_state.current_sql_result[selected_profile] = None + if st.session_state.current_model_id != "" and st.session_state.current_model_id in model_ids: model_index = model_ids.index(st.session_state.current_model_id) model_type = st.selectbox("Choose your model", model_ids, index=model_index) @@ -384,9 +399,9 @@ def main(): knowledge_search_flag = False # add select box for which model to use - if search_box != "Type your query here..." or \ - current_nlq_chain.is_visualization_config_changed(): + if search_box != "Type your query here..." or st.session_state.vision_change: if search_box is not None and len(search_box) > 0: + st.session_state.current_sql_result[selected_profile] = None with st.chat_message("user"): current_nlq_chain.set_question(search_box) st.session_state.messages[selected_profile].append( @@ -642,7 +657,7 @@ def main(): st.session_state.messages[selected_profile].append( {"role": "assistant", "content": current_search_sql_result, "type": "pandas"}) - do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile]) + do_visualize_results(selected_profile) else: st.markdown("No relevant data found") @@ -666,10 +681,8 @@ def main(): on_click=sample_question_clicked, args=[gen_sq_list[2]]) else: - - if current_nlq_chain.is_visualization_config_changed(): - if visualize_results_flag: - do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile]) + if visualize_results_flag: + do_visualize_results(selected_profile) if __name__ == '__main__':