Skip to content

Commit 7bc2e0f

Browse files
authored
Merge pull request #5 from aws-samples/intent
add intent code and change text input
2 parents f46953a + 4cd06fb commit 7bc2e0f

File tree

4 files changed

+137
-60
lines changed

4 files changed

+137
-60
lines changed

application/pages/1_🌍_Natural_Language_Querying.py

Lines changed: 78 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
from nlq.business.nlq_chain import NLQChain
1111
from nlq.business.profile import ProfileManagement
1212
from utils.database import get_db_url_dialect
13+
from nlq.business.vector_store import VectorStore
1314
from utils.llm import claude3_to_sql, create_vector_embedding_with_bedrock, retrieve_results_from_opensearch, \
14-
upload_results_to_opensearch
15+
upload_results_to_opensearch, get_query_intent
16+
1517

1618
def sample_question_clicked(sample):
1719
"""Update the selected_sample variable with the text of the clicked button"""
@@ -20,19 +22,9 @@ def sample_question_clicked(sample):
2022

2123
def upvote_clicked(question, sql, env_vars):
2224
# HACK: configurable opensearch endpoint
23-
target_profile = 'shopping_guide'
24-
aos_config = env_vars['data_sources'][target_profile]['opensearch']
25-
upload_results_to_opensearch(
26-
region_name=['region_name'],
27-
domain=aos_config['domain'],
28-
opensearch_user=aos_config['opensearch_user'],
29-
opensearch_password=aos_config['opensearch_password'],
30-
index_name=aos_config['index_name'],
31-
query=question,
32-
sql=sql,
33-
host=aos_config['opensearch_host'],
34-
port=aos_config['opensearch_port']
35-
)
25+
26+
current_profile = st.session_state.current_profile
27+
VectorStore.add_sample(current_profile, question, sql)
3628
logger.info(f'up voted "{question}" with sql "{sql}"')
3729

3830

@@ -128,6 +120,9 @@ def main():
128120
if 'nlq_chain' not in st.session_state:
129121
st.session_state['nlq_chain'] = None
130122

123+
if "messages" not in st.session_state:
124+
st.session_state.messages = {}
125+
131126
bedrock_model_ids = ['anthropic.claude-3-sonnet-20240229-v1:0', 'anthropic.claude-3-haiku-20240307-v1:0',
132127
'anthropic.claude-v2:1']
133128

@@ -139,7 +134,8 @@ def main():
139134
# clear session state
140135
st.session_state.selected_sample = ''
141136
st.session_state.current_profile = selected_profile
142-
137+
if selected_profile not in st.session_state.messages:
138+
st.session_state.messages[selected_profile] = []
143139
st.session_state.nlq_chain = NLQChain(selected_profile)
144140

145141
st.session_state['option'] = st.selectbox("Choose your option", ["Text2SQL"])
@@ -174,7 +170,8 @@ def main():
174170

175171
# Display the predefined search samples as buttons within columns
176172
for i, sample in enumerate(search_samples[0:question_column_number]):
177-
search_sample_columns[i].button(sample, use_container_width=True, on_click=sample_question_clicked, args=[sample])
173+
search_sample_columns[i].button(sample, use_container_width=True, on_click=sample_question_clicked,
174+
args=[sample])
178175

179176
# Display more predefined search samples as buttons within columns, if there are more samples than columns
180177
if len(search_samples) > question_column_number:
@@ -189,19 +186,32 @@ def main():
189186
else:
190187
col_num += 1
191188

192-
if "messages" not in st.session_state:
193-
st.session_state.messages = []
189+
# Display chat messages from history
190+
if selected_profile in st.session_state.messages:
191+
for message in st.session_state.messages[selected_profile]:
192+
with st.chat_message(message["role"]):
193+
if "SQL:" in message["content"]:
194+
st.code(message["content"].replace("SQL:", ""), language="sql")
195+
elif isinstance(message["content"], pd.DataFrame):
196+
st.table(message["content"])
197+
else:
198+
st.markdown(message["content"])
199+
200+
text_placeholder = "Type your query here..."
194201

195-
search_box = st.text_input('Search Box', value=st.session_state['selected_sample'],
196-
placeholder='Type your query here...', max_chars=1000, key='search_box',
197-
label_visibility='collapsed')
202+
search_box = st.chat_input(placeholder=text_placeholder)
203+
if st.session_state['selected_sample'] != "":
204+
search_box = st.session_state['selected_sample']
205+
st.session_state['selected_sample'] = ""
198206

199207
current_nlq_chain = st.session_state.nlq_chain
200208

209+
search_intent_flag = True
210+
201211
# add select box for which model to use
202-
if st.button('Run', type='primary', use_container_width=True) or \
212+
if search_box != "Type your query here..." or \
203213
current_nlq_chain.is_visualization_config_changed():
204-
if len(search_box) > 0:
214+
if search_box is not None and len(search_box) > 0:
205215
with st.chat_message("user"):
206216
current_nlq_chain.set_question(search_box)
207217
st.markdown(current_nlq_chain.get_question())
@@ -266,52 +276,61 @@ def main():
266276
conn_name = database_profile['conn_name']
267277
db_url = ConnectionManagement.get_db_url_by_name(conn_name)
268278
database_profile['db_url'] = db_url
269-
response = claude3_to_sql(database_profile['tables_info'],
270-
database_profile['hints'],
271-
search_box,
272-
model_id=model_type,
273-
examples=retrieve_result,
274-
dialect=get_db_url_dialect(database_profile['db_url']),
275-
model_provider=model_provider)
276-
277-
logger.info(f'got llm response: {response}')
278-
current_nlq_chain.set_generated_sql_response(response)
279+
280+
intent_response = get_query_intent(model_type, search_box)
281+
282+
intent = intent_response.get("intent", "normal_search")
283+
if intent == "reject_search":
284+
search_intent_flag = False
285+
286+
if search_intent_flag:
287+
response = claude3_to_sql(database_profile['tables_info'],
288+
database_profile['hints'],
289+
search_box,
290+
model_id=model_type,
291+
examples=retrieve_result,
292+
dialect=get_db_url_dialect(database_profile['db_url']),
293+
model_provider=model_provider)
294+
295+
logger.info(f'got llm response: {response}')
296+
current_nlq_chain.set_generated_sql_response(response)
279297
else:
280298
logger.info('get generated sql from memory')
281299

282-
st.session_state.messages = []
283-
284-
# Add user message to chat history
285-
st.session_state.messages.append({"role": "user", "content": st.session_state['selected_sample']})
300+
if search_intent_flag:
301+
# Add user message to chat history
302+
st.session_state.messages[selected_profile].append({"role": "user", "content": search_box})
286303

287-
# Add assistant response to chat history
288-
st.session_state.messages.append({"role": "assistant", "content":
289-
current_nlq_chain.get_generated_sql()})
290-
st.session_state.messages.append({"role": "assistant", "content":
291-
current_nlq_chain.get_generated_sql_explain()})
304+
# Add assistant response to chat history
305+
st.session_state.messages[selected_profile].append(
306+
{"role": "assistant", "content": "SQL:" + current_nlq_chain.get_generated_sql()})
307+
st.session_state.messages[selected_profile].append(
308+
{"role": "assistant", "content": current_nlq_chain.get_generated_sql_explain()})
292309

293-
st.markdown('The generated SQL statement is:')
294-
st.code(current_nlq_chain.get_generated_sql(), language="sql")
310+
st.markdown('The generated SQL statement is:')
311+
st.code(current_nlq_chain.get_generated_sql(), language="sql")
295312

296-
st.markdown('Generation process explanations:')
297-
st.markdown(current_nlq_chain.get_generated_sql_explain())
313+
st.markdown('Generation process explanations:')
314+
st.markdown(current_nlq_chain.get_generated_sql_explain())
298315

299-
st.markdown('You can provide feedback:')
316+
st.markdown('You can provide feedback:')
300317

301-
# add a upvote(green)/downvote button with logo
302-
feedback = st.columns(2)
303-
feedback[0].button('👍 Upvote (save as embedding for retrieval)', type='secondary',
304-
use_container_width=True,
305-
on_click=upvote_clicked,
306-
args=[current_nlq_chain.get_question(),
307-
current_nlq_chain.get_generated_sql(),
308-
env_vars])
318+
# add a upvote(green)/downvote button with logo
319+
feedback = st.columns(2)
320+
feedback[0].button('👍 Upvote (save as embedding for retrieval)', type='secondary',
321+
use_container_width=True,
322+
on_click=upvote_clicked,
323+
args=[current_nlq_chain.get_question(),
324+
current_nlq_chain.get_generated_sql(),
325+
env_vars])
309326

310-
if feedback[1].button('👎 Downvote', type='secondary', use_container_width=True):
311-
# do something here
312-
pass
327+
if feedback[1].button('👎 Downvote', type='secondary', use_container_width=True):
328+
# do something here
329+
pass
330+
else:
331+
st.markdown('Your query statement is currently not supported by the system')
313332

314-
if visualize_results:
333+
if visualize_results and search_intent_flag:
315334
do_visualize_results(current_nlq_chain)
316335
else:
317336
st.error("Please enter a valid query.")

application/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ PyMySQL==1.1.0
88
python-dotenv~=1.0.0
99
plotly~=5.18.0
1010
cryptography==42.0.4
11+
langchain~=0.1.11
12+
langchain-core~=0.1.30

application/utils/llm.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from opensearchpy import OpenSearch
77
from utils import opensearch
88
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
9-
DEFAULT_DIALECT_PROMPT
9+
DEFAULT_DIALECT_PROMPT, SEARCH_INTENT_PROMPT_CLAUDE3
1010
import os
1111
from loguru import logger
12+
from langchain_core.output_parsers import JsonOutputParser
1213

1314
BEDROCK_AWS_REGION = os.environ.get('BEDROCK_REGION', 'us-west-2')
1415

@@ -24,6 +25,7 @@
2425
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html
2526

2627
bedrock = None
28+
json_parse = JsonOutputParser()
2729

2830

2931
@logger.catch
@@ -123,6 +125,22 @@ def claude3_to_sql(ddl, hints, search_box, examples=None, model_id=None, dialect
123125
return final_response
124126

125127

128+
def get_query_intent(model_id, search_box):
    """Classify the intent of a user query with a Claude 3 model.

    Sends ``search_box`` as the single user message, with
    ``SEARCH_INTENT_PROMPT_CLAUDE3`` as the system prompt, and parses the
    text of the first content item of the reply as JSON.

    Args:
        model_id: Model identifier passed through to ``invoke_model_claude3``.
        search_box: The query text to classify.

    Returns:
        dict: The parsed reply. Falls back to ``{"intent": "normal_search"}``
        when the model call or the JSON parsing fails, or when the reply does
        not parse to a dict.
    """
    default_intent = {"intent": "normal_search"}
    max_tokens = 2048
    messages = [{"role": "user", "content": search_box}]
    try:
        response = invoke_model_claude3(model_id, SEARCH_INTENT_PROMPT_CLAUDE3, messages, max_tokens)
        final_response = response.get("content")[0].get("text")
        logger.info(f'{final_response=}')
        intent_result = json_parse.parse(final_response)
    except Exception:
        # Deliberately best-effort: a failed classification must not block the
        # query, so fall back to the default -- but record why it failed
        # instead of discarding the error.
        logger.exception('failed to get query intent, falling back to default intent')
        return default_intent

    # Callers read the result with ``.get("intent", ...)``, so anything that is
    # not a dict (e.g. a JSON list or bare string) would crash them.
    if not isinstance(intent_result, dict):
        logger.warning(f'unexpected intent result type {type(intent_result).__name__}, '
                       'falling back to default intent')
        return default_intent
    return intent_result
142+
143+
126144
def create_vector_embedding_with_bedrock(text, index_name):
127145
payload = {"inputText": f"{text}"}
128146
body = json.dumps(payload)

application/utils/prompt.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,41 @@
1515
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
1616
Pay attention to use CURDATE() function to get the current date, if the question involves "today". Aside from giving the SQL answer, concisely explain yourself after giving the answer
1717
in the same language as the question.""".format(top_k=TOP_K)
18+
19+
SEARCH_INTENT_PROMPT_CLAUDE3 = """You are an intent classifier and entity extractor, and you need to perform intent classification and entity extraction on search queries.
20+
Background: I want to query data in the database, and you need to help me determine the user's relevant intent and extract the keywords from the query statement. Finally, return a JSON structure.
21+
22+
There are 2 main intents:
23+
<intent>
24+
- normal_search: Query relevant data from the data table
25+
- reject_search: Delete data from the table, add data to the table, modify data in the table, display usernames and passwords in the table, and other topics unrelated to data query
26+
</intent>
27+
28+
When the intent is normal_search, you need to extract the keywords from the query statement.
29+
30+
Here are some examples:
31+
32+
<example>
33+
question : 希尔顿在欧洲上线了多少酒店数
34+
answer :
35+
{
36+
"intent" : "normal_search",
37+
"slot" : ["希尔顿", "欧洲", "上线", "酒店数"]
38+
}
39+
40+
question : 苹果手机3月份在京东有多少订单
41+
answer :
42+
{
43+
"intent" : "normal_search",
44+
"slot" : ["苹果手机", "3月", "京东", "订单"]
45+
}
46+
47+
question : 修改订单表中的第一行数据
48+
answer :
49+
{
50+
"intent" : "reject_search"
51+
}
52+
</example>
53+
54+
Please perform intent recognition and entity extraction. Return only the JSON structure, without any other annotations.
55+
"""

0 commit comments

Comments
 (0)