Skip to content

change do_visualize_results #262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 46 additions & 33 deletions application/pages/1_🌍_Generative_BI_Playground.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import json
import os
import streamlit as st
import pandas as pd
import plotly.express as px
from dotenv import load_dotenv
import logging
import random

from api.service import user_feedback_downvote
from nlq.business.connection import ConnectionManagement
Expand Down Expand Up @@ -65,6 +63,8 @@ def downvote_clicked(question, comment):
def clean_st_history(selected_profile):
st.session_state.messages[selected_profile] = []
st.session_state.query_rewrite_history[selected_profile] = []
st.session_state.current_sql_result[selected_profile] = None



def get_user_history(selected_profile: str):
Expand All @@ -80,41 +80,50 @@ def get_user_history(selected_profile: str):
history_query.append(messages["role"] + ":" + messages["content"])
return history_query

def set_vision_change():
st.session_state.vision_change = True


def do_visualize_results(nlq_chain, sql_result):
sql_query_result = sql_result
def do_visualize_results(selected_profile):
sql_query_result = st.session_state.current_sql_result[selected_profile]
if sql_query_result is not None:
nlq_chain.set_visualization_config_change(False)
# Auto-detect columns
visualize_config_columns = st.columns(3)
available_columns = sql_query_result.columns.tolist()

# Initialize session state for x_column and y_column if not already present
if 'x_column' not in st.session_state or st.session_state.x_column is None:
st.session_state.x_column = available_columns[0] if available_columns else None
if 'y_column' not in st.session_state or st.session_state.x_column is None:
st.session_state.y_column = available_columns[0] if available_columns else None

available_columns = sql_query_result.columns
# Layout configuration
col1, col2, col3 = st.columns([1, 1, 2])

# Chart type selection
chart_type = col1.selectbox('Choose the chart type', ['Table', 'Bar', 'Line', 'Pie'],
on_change=set_vision_change)

# hacky way to get around the issue of selectbox not updating when the options change
chart_type = visualize_config_columns[0].selectbox('Choose the chart type',
['Table', 'Bar', 'Line', 'Pie'],
on_change=nlq_chain.set_visualization_config_change
)
if chart_type != 'Table':
x_column = visualize_config_columns[1].selectbox(f'Choose x-axis column', available_columns,
on_change=nlq_chain.set_visualization_config_change,
key=random.randint(0, 10000)
)
y_column = visualize_config_columns[2].selectbox('Choose y-axis column',
reversed(available_columns.to_list()),
on_change=nlq_chain.set_visualization_config_change,
key=random.randint(0, 10000)
)
# X-axis and Y-axis selection
st.session_state.x_column = col2.selectbox('Choose x-axis column', available_columns,
on_change=set_vision_change,
index=available_columns.index(
st.session_state.x_column) if st.session_state.x_column in available_columns else 0)
st.session_state.y_column = col3.selectbox('Choose y-axis column', available_columns,
on_change=set_vision_change,
index=available_columns.index(
st.session_state.y_column) if st.session_state.y_column in available_columns else 0)

# Visualization
if chart_type == 'Table':
st.dataframe(sql_query_result, hide_index=True)
elif chart_type == 'Bar':
st.plotly_chart(px.bar(sql_query_result, x=x_column, y=y_column))
st.plotly_chart(px.bar(sql_query_result, x=st.session_state.x_column, y=st.session_state.y_column))
elif chart_type == 'Line':
st.plotly_chart(px.line(sql_query_result, x=x_column, y=y_column))
st.plotly_chart(px.line(sql_query_result, x=st.session_state.x_column, y=st.session_state.y_column))
elif chart_type == 'Pie':
st.plotly_chart(px.pie(sql_query_result, names=x_column, values=y_column))
else:
st.markdown('No visualization generated.')
st.plotly_chart(px.pie(sql_query_result, names=st.session_state.x_column, values=st.session_state.y_column))



def recurrent_display(messages, i):
Expand Down Expand Up @@ -247,6 +256,9 @@ def main():
all_profiles = ProfileManagement.get_all_profiles_with_info()
st.session_state['profiles'] = all_profiles

if "vision_change" not in st.session_state:
st.session_state["vision_change"] = False

if 'selected_sample' not in st.session_state:
st.session_state['selected_sample'] = ''

Expand Down Expand Up @@ -313,6 +325,9 @@ def main():
st.session_state.query_rewrite_history[selected_profile] = []
st.session_state.nlq_chain = NLQChain(selected_profile)

if selected_profile not in st.session_state.current_sql_result:
st.session_state.current_sql_result[selected_profile] = None

if st.session_state.current_model_id != "" and st.session_state.current_model_id in model_ids:
model_index = model_ids.index(st.session_state.current_model_id)
model_type = st.selectbox("Choose your model", model_ids, index=model_index)
Expand Down Expand Up @@ -384,9 +399,9 @@ def main():
knowledge_search_flag = False

# add select box for which model to use
if search_box != "Type your query here..." or \
current_nlq_chain.is_visualization_config_changed():
if search_box != "Type your query here..." or st.session_state.vision_change:
if search_box is not None and len(search_box) > 0:
st.session_state.current_sql_result[selected_profile] = None
with st.chat_message("user"):
current_nlq_chain.set_question(search_box)
st.session_state.messages[selected_profile].append(
Expand Down Expand Up @@ -642,7 +657,7 @@ def main():
st.session_state.messages[selected_profile].append(
{"role": "assistant", "content": current_search_sql_result, "type": "pandas"})

do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile])
do_visualize_results(selected_profile)
else:
st.markdown("No relevant data found")

Expand All @@ -666,10 +681,8 @@ def main():
on_click=sample_question_clicked,
args=[gen_sq_list[2]])
else:

if current_nlq_chain.is_visualization_config_changed():
if visualize_results_flag:
do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile])
if visualize_results_flag:
do_visualize_results(selected_profile)


if __name__ == '__main__':
Expand Down