Skip to content

Commit 8081be8

Browse files
authored
Merge pull request #315 from aws-samples/v1.8.0_dev
remove some code
2 parents 5375bb5 + 9e09a79 commit 8081be8

File tree

1 file changed

+1
-176
lines changed

1 file changed

+1
-176
lines changed

application/api/service.py

Lines changed: 1 addition & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import json
2-
from typing import Union
32
from dotenv import load_dotenv
43

54
from nlq.business.connection import ConnectionManagement
6-
from nlq.business.datasource.factory import DataSourceFactory
75
from nlq.business.log_feedback import FeedBackManagement
86
from nlq.business.model import ModelManagement
97
from nlq.business.nlq_chain import NLQChain
@@ -13,13 +11,9 @@
1311
from nlq.core.chat_context import ProcessingContext
1412
from nlq.core.state import QueryState
1513
from nlq.core.state_machine import QueryStateMachine
16-
from utils.database import get_db_url_dialect
17-
from utils.domain import SearchTextSqlResult
18-
from utils.llm import text_to_sql, get_query_intent
1914
from utils.logging import getLogger
20-
from utils.opensearch import get_retrieve_opensearch
2115
from utils.env_var import opensearch_info
22-
from utils.tool import generate_log_id, get_current_time, get_generated_sql, serialize_timestamp
16+
from utils.tool import generate_log_id, get_current_time, serialize_timestamp
2317
from .schemas import Question, Example, Option, Message, HistoryMessage
2418
from .exception_handler import BizException
2519
from utils.constant import BEDROCK_MODEL_IDS
@@ -83,47 +77,6 @@ def get_history_by_user_profile(user_id: str, profile_name: str):
8377
return chat_history
8478

8579

86-
def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_response_stream=False) -> Union[
87-
str, dict]:
88-
logger.info('try to get generated sql from LLM')
89-
90-
entity_slot_retrieve = []
91-
all_profiles = ProfileManagement.get_all_profiles_with_info()
92-
database_profile = all_profiles[question.profile_name]
93-
if question.intent_ner_recognition:
94-
intent_response = get_query_intent(question.bedrock_model_id, question.keywords, database_profile['prompt_map'])
95-
intent = intent_response.get("intent", "normal_search")
96-
if intent == "reject_search":
97-
raise BizException(ErrorEnum.NOT_SUPPORTED)
98-
entity_slot = intent_response.get("slot", [])
99-
if entity_slot:
100-
for each_entity in entity_slot:
101-
entity_retrieve = get_retrieve_opensearch(opensearch_info, each_entity, "ner", question.profile_name, 1,
102-
0.7)
103-
if entity_retrieve:
104-
entity_slot_retrieve.extend(entity_retrieve)
105-
106-
# Whether Retrieving Few Shots from Database
107-
logger.info('Sending request...')
108-
# fix db url is Empty
109-
if database_profile['db_url'] == '':
110-
conn_name = database_profile['conn_name']
111-
db_url = ConnectionManagement.get_db_url_by_name(conn_name)
112-
database_profile['db_url'] = db_url
113-
114-
response = text_to_sql(database_profile['tables_info'],
115-
database_profile['hints'],
116-
database_profile['prompt_map'],
117-
question.keywords,
118-
model_id=question.bedrock_model_id,
119-
sql_examples=current_nlq_chain.get_retrieve_samples(),
120-
ner_example=entity_slot_retrieve,
121-
dialect=get_db_url_dialect(database_profile['db_url']),
122-
model_provider=None,
123-
with_response_stream=with_response_stream, )
124-
return response
125-
126-
12780
async def ask_websocket(websocket: WebSocket, question: Question):
12881
logger.info(question)
12982
session_id = question.session_id
@@ -332,134 +285,6 @@ def user_feedback_downvote(data_profiles: str, user_id: str, session_id: str, qu
332285
return False
333286

334287

335-
def ask_with_response_stream(question: Question, current_nlq_chain: NLQChain) -> dict:
336-
logger.info('try to get generated sql from LLM')
337-
response = get_result_from_llm(question, current_nlq_chain, True)
338-
logger.info("got llm response")
339-
return response
340-
341-
342-
def get_executed_result(current_nlq_chain: NLQChain) -> str:
343-
all_profiles = ProfileManagement.get_all_profiles_with_info()
344-
sql_query_result = current_nlq_chain.get_executed_result_df(all_profiles[current_nlq_chain.profile])
345-
final_sql_query_result = sql_query_result.to_markdown()
346-
return final_sql_query_result
347-
348-
349-
async def normal_text_search_websocket(websocket: WebSocket, session_id: str, search_box, model_type, database_profile,
350-
entity_slot, opensearch_info, selected_profile, use_rag, user_id,
351-
model_provider=None, username=None):
352-
entity_slot_retrieve = []
353-
retrieve_result = []
354-
response = ""
355-
sql = ""
356-
search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve,
357-
retrieve_result=retrieve_result, response=response, sql=sql)
358-
try:
359-
if database_profile['db_url'] == '':
360-
conn_name = database_profile['conn_name']
361-
db_url = ConnectionManagement.get_db_url_by_name(conn_name)
362-
database_profile['db_url'] = db_url
363-
# TODO: db_type already set in profile
364-
database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name)
365-
366-
if len(entity_slot) > 0 and use_rag:
367-
await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "start",
368-
user_id)
369-
for each_entity in entity_slot:
370-
entity_retrieve = get_retrieve_opensearch(opensearch_info, each_entity, "ner",
371-
selected_profile, 1, 0.7)
372-
if len(entity_retrieve) > 0:
373-
entity_slot_retrieve.extend(entity_retrieve)
374-
await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "end", user_id)
375-
376-
if use_rag:
377-
await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "start", user_id)
378-
retrieve_result = get_retrieve_opensearch(opensearch_info, search_box, "query",
379-
selected_profile, 3, 0.5)
380-
await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "end", user_id)
381-
382-
await response_websocket(websocket, session_id, "Generating SQL", ContentEnum.STATE, "start", user_id)
383-
384-
response = text_to_sql(database_profile['tables_info'],
385-
database_profile['hints'],
386-
database_profile['prompt_map'],
387-
search_box,
388-
model_id=model_type,
389-
sql_examples=retrieve_result,
390-
ner_example=entity_slot_retrieve,
391-
dialect=database_profile['db_type'],
392-
model_provider=model_provider)
393-
logger.info(f'{response=}')
394-
await response_websocket(websocket, session_id, "Generating SQL", ContentEnum.STATE, "end", user_id)
395-
sql = get_generated_sql(response)
396-
# post-processing the sql for row level security
397-
post_sql = DataSourceFactory.apply_row_level_security_for_sql(
398-
database_profile['db_type'],
399-
sql,
400-
database_profile['row_level_security_config'],
401-
username
402-
)
403-
404-
search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve,
405-
retrieve_result=retrieve_result, response=response, sql="")
406-
search_result.entity_slot_retrieve = entity_slot_retrieve
407-
search_result.retrieve_result = retrieve_result
408-
search_result.response = response
409-
search_result.sql = post_sql
410-
search_result.original_sql = sql
411-
except Exception as e:
412-
logger.exception(e)
413-
return search_result
414-
415-
416-
async def normal_sql_regenerating_websocket(websocket: WebSocket, session_id: str, search_box, model_type,
417-
database_profile, entity_slot_retrieve, retrieve_result, additional_info,
418-
username: str):
419-
entity_slot_retrieve = entity_slot_retrieve
420-
retrieve_result = retrieve_result
421-
response = ""
422-
sql = ""
423-
search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve,
424-
retrieve_result=retrieve_result, response=response, sql=sql)
425-
try:
426-
if database_profile['db_url'] == '':
427-
conn_name = database_profile['conn_name']
428-
db_url = ConnectionManagement.get_db_url_by_name(conn_name)
429-
database_profile['db_url'] = db_url
430-
database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name)
431-
432-
response = text_to_sql(database_profile['tables_info'],
433-
database_profile['hints'],
434-
database_profile['prompt_map'],
435-
search_box,
436-
model_id=model_type,
437-
sql_examples=retrieve_result,
438-
ner_example=entity_slot_retrieve,
439-
dialect=database_profile['db_type'],
440-
model_provider=None,
441-
additional_info=additional_info)
442-
logger.info("normal_sql_regenerating_websocket")
443-
logger.info(f'{response=}')
444-
sql = get_generated_sql(response)
445-
post_sql = DataSourceFactory.apply_row_level_security_for_sql(
446-
database_profile['db_type'],
447-
sql,
448-
database_profile['row_level_security_config'],
449-
username
450-
)
451-
search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve,
452-
retrieve_result=retrieve_result, response=response, sql="")
453-
search_result.entity_slot_retrieve = entity_slot_retrieve
454-
search_result.retrieve_result = retrieve_result
455-
search_result.response = response
456-
search_result.sql = post_sql
457-
search_result.original_sql = sql
458-
except Exception as e:
459-
logger.error(e)
460-
return search_result
461-
462-
463288
async def response_websocket(websocket: WebSocket, session_id: str, content,
464289
content_type: ContentEnum = ContentEnum.COMMON, status: str = "-1",
465290
user_id: str = "admin"):

0 commit comments

Comments
 (0)