Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 78 additions & 59 deletions application/pages/1_🌍_Natural_Language_Querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from nlq.business.nlq_chain import NLQChain
from nlq.business.profile import ProfileManagement
from utils.database import get_db_url_dialect
from nlq.business.vector_store import VectorStore
from utils.llm import claude3_to_sql, create_vector_embedding_with_bedrock, retrieve_results_from_opensearch, \
upload_results_to_opensearch
upload_results_to_opensearch, get_query_intent


def sample_question_clicked(sample):
"""Update the selected_sample variable with the text of the clicked button"""
Expand All @@ -20,19 +22,9 @@ def sample_question_clicked(sample):

def upvote_clicked(question, sql, env_vars):
# HACK: configurable opensearch endpoint
target_profile = 'shopping_guide'
aos_config = env_vars['data_sources'][target_profile]['opensearch']
upload_results_to_opensearch(
region_name=['region_name'],
domain=aos_config['domain'],
opensearch_user=aos_config['opensearch_user'],
opensearch_password=aos_config['opensearch_password'],
index_name=aos_config['index_name'],
query=question,
sql=sql,
host=aos_config['opensearch_host'],
port=aos_config['opensearch_port']
)

current_profile = st.session_state.current_profile
VectorStore.add_sample(current_profile, question, sql)
logger.info(f'up voted "{question}" with sql "{sql}"')


Expand Down Expand Up @@ -128,6 +120,9 @@ def main():
if 'nlq_chain' not in st.session_state:
st.session_state['nlq_chain'] = None

if "messages" not in st.session_state:
st.session_state.messages = {}

bedrock_model_ids = ['anthropic.claude-3-sonnet-20240229-v1:0', 'anthropic.claude-3-haiku-20240307-v1:0',
'anthropic.claude-v2:1']

Expand All @@ -139,7 +134,8 @@ def main():
# clear session state
st.session_state.selected_sample = ''
st.session_state.current_profile = selected_profile

if selected_profile not in st.session_state.messages:
st.session_state.messages[selected_profile] = []
st.session_state.nlq_chain = NLQChain(selected_profile)

st.session_state['option'] = st.selectbox("Choose your option", ["Text2SQL"])
Expand Down Expand Up @@ -174,7 +170,8 @@ def main():

# Display the predefined search samples as buttons within columns
for i, sample in enumerate(search_samples[0:question_column_number]):
search_sample_columns[i].button(sample, use_container_width=True, on_click=sample_question_clicked, args=[sample])
search_sample_columns[i].button(sample, use_container_width=True, on_click=sample_question_clicked,
args=[sample])

# Display more predefined search samples as buttons within columns, if there are more samples than columns
if len(search_samples) > question_column_number:
Expand All @@ -189,19 +186,32 @@ def main():
else:
col_num += 1

if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages from history
if selected_profile in st.session_state.messages:
for message in st.session_state.messages[selected_profile]:
with st.chat_message(message["role"]):
if "SQL:" in message["content"]:
st.code(message["content"].replace("SQL:", ""), language="sql")
elif isinstance(message["content"], pd.DataFrame):
st.table(message["content"])
else:
st.markdown(message["content"])

text_placeholder = "Type your query here..."

search_box = st.text_input('Search Box', value=st.session_state['selected_sample'],
placeholder='Type your query here...', max_chars=1000, key='search_box',
label_visibility='collapsed')
search_box = st.chat_input(placeholder=text_placeholder)
if st.session_state['selected_sample'] != "":
search_box = st.session_state['selected_sample']
st.session_state['selected_sample'] = ""

current_nlq_chain = st.session_state.nlq_chain

search_intent_flag = True

# add select box for which model to use
if st.button('Run', type='primary', use_container_width=True) or \
if search_box != "Type your query here..." or \
current_nlq_chain.is_visualization_config_changed():
if len(search_box) > 0:
if search_box is not None and len(search_box) > 0:
with st.chat_message("user"):
current_nlq_chain.set_question(search_box)
st.markdown(current_nlq_chain.get_question())
Expand Down Expand Up @@ -266,52 +276,61 @@ def main():
conn_name = database_profile['conn_name']
db_url = ConnectionManagement.get_db_url_by_name(conn_name)
database_profile['db_url'] = db_url
response = claude3_to_sql(database_profile['tables_info'],
database_profile['hints'],
search_box,
model_id=model_type,
examples=retrieve_result,
dialect=get_db_url_dialect(database_profile['db_url']),
model_provider=model_provider)

logger.info(f'got llm response: {response}')
current_nlq_chain.set_generated_sql_response(response)

intent_response = get_query_intent(model_type, search_box)

intent = intent_response.get("intent", "normal_search")
if intent == "reject_search":
search_intent_flag = False

if search_intent_flag:
response = claude3_to_sql(database_profile['tables_info'],
database_profile['hints'],
search_box,
model_id=model_type,
examples=retrieve_result,
dialect=get_db_url_dialect(database_profile['db_url']),
model_provider=model_provider)

logger.info(f'got llm response: {response}')
current_nlq_chain.set_generated_sql_response(response)
else:
logger.info('get generated sql from memory')

st.session_state.messages = []

# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": st.session_state['selected_sample']})
if search_intent_flag:
# Add user message to chat history
st.session_state.messages[selected_profile].append({"role": "user", "content": search_box})

# Add assistant response to chat history
st.session_state.messages.append({"role": "assistant", "content":
current_nlq_chain.get_generated_sql()})
st.session_state.messages.append({"role": "assistant", "content":
current_nlq_chain.get_generated_sql_explain()})
# Add assistant response to chat history
st.session_state.messages[selected_profile].append(
{"role": "assistant", "content": "SQL:" + current_nlq_chain.get_generated_sql()})
st.session_state.messages[selected_profile].append(
{"role": "assistant", "content": current_nlq_chain.get_generated_sql_explain()})

st.markdown('The generated SQL statement is:')
st.code(current_nlq_chain.get_generated_sql(), language="sql")
st.markdown('The generated SQL statement is:')
st.code(current_nlq_chain.get_generated_sql(), language="sql")

st.markdown('Generation process explanations:')
st.markdown(current_nlq_chain.get_generated_sql_explain())
st.markdown('Generation process explanations:')
st.markdown(current_nlq_chain.get_generated_sql_explain())

st.markdown('You can provide feedback:')
st.markdown('You can provide feedback:')

# add a upvote(green)/downvote button with logo
feedback = st.columns(2)
feedback[0].button('👍 Upvote (save as embedding for retrieval)', type='secondary',
use_container_width=True,
on_click=upvote_clicked,
args=[current_nlq_chain.get_question(),
current_nlq_chain.get_generated_sql(),
env_vars])
# add a upvote(green)/downvote button with logo
feedback = st.columns(2)
feedback[0].button('👍 Upvote (save as embedding for retrieval)', type='secondary',
use_container_width=True,
on_click=upvote_clicked,
args=[current_nlq_chain.get_question(),
current_nlq_chain.get_generated_sql(),
env_vars])

if feedback[1].button('👎 Downvote', type='secondary', use_container_width=True):
# do something here
pass
if feedback[1].button('👎 Downvote', type='secondary', use_container_width=True):
# do something here
pass
else:
st.markdown('Your query statement is currently not supported by the system')

if visualize_results:
if visualize_results and search_intent_flag:
do_visualize_results(current_nlq_chain)
else:
st.error("Please enter a valid query.")
Expand Down
2 changes: 2 additions & 0 deletions application/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ PyMySQL==1.1.0
python-dotenv~=1.0.0
plotly~=5.18.0
cryptography==42.0.4
langchain~=0.1.11
langchain-core~=0.1.30
20 changes: 19 additions & 1 deletion application/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from opensearchpy import OpenSearch
from utils import opensearch
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
DEFAULT_DIALECT_PROMPT
DEFAULT_DIALECT_PROMPT, SEARCH_INTENT_PROMPT_CLAUDE3
import os
from loguru import logger
from langchain_core.output_parsers import JsonOutputParser

BEDROCK_AWS_REGION = os.environ.get('BEDROCK_REGION', 'us-west-2')

Expand All @@ -24,6 +25,7 @@
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html

bedrock = None
json_parse = JsonOutputParser()


@logger.catch
Expand Down Expand Up @@ -123,6 +125,22 @@ def claude3_to_sql(ddl, hints, search_box, examples=None, model_id=None, dialect
return final_response


def get_query_intent(model_id, search_box):
    """Classify the user's query intent via a Claude 3 call.

    Sends ``search_box`` to the model with the intent-classification system
    prompt and parses the JSON answer.

    :param model_id: Bedrock model identifier to invoke.
    :param search_box: raw user query text to classify.
    :return: dict parsed from the model's JSON reply — expected to contain an
        ``"intent"`` key (``"normal_search"`` or ``"reject_search"``) and, for
        normal searches, a ``"slot"`` keyword list. On ANY failure (model call
        error, non-JSON reply, ...) falls back to ``{"intent": "normal_search"}``
        so the caller proceeds with a regular search.
    """
    default_intent = {"intent": "normal_search"}
    try:
        system_prompt = SEARCH_INTENT_PROMPT_CLAUDE3
        # 2048 tokens is ample: the expected reply is a small JSON object.
        max_tokens = 2048
        messages = [{"role": "user", "content": search_box}]
        response = invoke_model_claude3(model_id, system_prompt, messages, max_tokens)
        final_response = response.get("content")[0].get("text")
        logger.info(f'{final_response=}')
        # JsonOutputParser tolerates markdown fences around the JSON payload.
        intent_result_dict = json_parse.parse(final_response)
        return intent_result_dict
    except Exception as e:
        # Deliberate best-effort fallback, but log the cause instead of
        # silently swallowing it — otherwise intent failures are invisible.
        logger.exception(f'intent detection failed, defaulting to normal_search: {e}')
        return default_intent


def create_vector_embedding_with_bedrock(text, index_name):
payload = {"inputText": f"{text}"}
body = json.dumps(payload)
Expand Down
38 changes: 38 additions & 0 deletions application/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,41 @@
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today". Aside from giving the SQL answer, concisely explain yourself after giving the answer
in the same language as the question.""".format(top_k=TOP_K)

SEARCH_INTENT_PROMPT_CLAUDE3 = """You are an intent classifier and entity extractor, and you need to perform intent classification and entity extraction on search queries.
Background: I want to query data in the database, and you need to help me determine the user's relevant intent and extract the keywords from the query statement. Finally, return a JSON structure.

There are 2 main intents:
<intent>
- normal_search: Query relevant data from the data table
- reject_search: Delete data from the table, add data to the table, modify data in the table, display usernames and passwords in the table, and other topics unrelated to data query
</intent>

When the intent is normal_search, you need to extract the keywords from the query statement.

Here are some examples:

<example>
question : 希尔顿在欧洲上线了多少酒店数
answer :
{
"intent" : "normal_search",
"slot" : ["希尔顿", "欧洲", "上线", "酒店数"]
}

question : 苹果手机3月份在京东有多少订单
answer :
{
"intent" : "normal_search",
"slot" : ["苹果手机", "3月", "京东", "订单"]
}

question : 修改订单表中的第一行数据
answer :
{
"intent" : "reject_search"
}
</example>

Please perform intent recognition and entity extraction. Return only the JSON structure, without any other annotations.
"""