|
1 | 1 | import json |
2 | | -from typing import Union |
3 | 2 | from dotenv import load_dotenv |
4 | 3 |
|
5 | 4 | from nlq.business.connection import ConnectionManagement |
6 | | -from nlq.business.datasource.factory import DataSourceFactory |
7 | 5 | from nlq.business.log_feedback import FeedBackManagement |
8 | 6 | from nlq.business.model import ModelManagement |
9 | 7 | from nlq.business.nlq_chain import NLQChain |
|
13 | 11 | from nlq.core.chat_context import ProcessingContext |
14 | 12 | from nlq.core.state import QueryState |
15 | 13 | from nlq.core.state_machine import QueryStateMachine |
16 | | -from utils.database import get_db_url_dialect |
17 | | -from utils.domain import SearchTextSqlResult |
18 | | -from utils.llm import text_to_sql, get_query_intent |
19 | 14 | from utils.logging import getLogger |
20 | | -from utils.opensearch import get_retrieve_opensearch |
21 | 15 | from utils.env_var import opensearch_info |
22 | | -from utils.tool import generate_log_id, get_current_time, get_generated_sql, serialize_timestamp |
| 16 | +from utils.tool import generate_log_id, get_current_time, serialize_timestamp |
23 | 17 | from .schemas import Question, Example, Option, Message, HistoryMessage |
24 | 18 | from .exception_handler import BizException |
25 | 19 | from utils.constant import BEDROCK_MODEL_IDS |
@@ -83,47 +77,6 @@ def get_history_by_user_profile(user_id: str, profile_name: str): |
83 | 77 | return chat_history |
84 | 78 |
|
85 | 79 |
|
86 | | -def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_response_stream=False) -> Union[ |
87 | | - str, dict]: |
88 | | - logger.info('try to get generated sql from LLM') |
89 | | - |
90 | | - entity_slot_retrieve = [] |
91 | | - all_profiles = ProfileManagement.get_all_profiles_with_info() |
92 | | - database_profile = all_profiles[question.profile_name] |
93 | | - if question.intent_ner_recognition: |
94 | | - intent_response = get_query_intent(question.bedrock_model_id, question.keywords, database_profile['prompt_map']) |
95 | | - intent = intent_response.get("intent", "normal_search") |
96 | | - if intent == "reject_search": |
97 | | - raise BizException(ErrorEnum.NOT_SUPPORTED) |
98 | | - entity_slot = intent_response.get("slot", []) |
99 | | - if entity_slot: |
100 | | - for each_entity in entity_slot: |
101 | | - entity_retrieve = get_retrieve_opensearch(opensearch_info, each_entity, "ner", question.profile_name, 1, |
102 | | - 0.7) |
103 | | - if entity_retrieve: |
104 | | - entity_slot_retrieve.extend(entity_retrieve) |
105 | | - |
106 | | - # Whether Retrieving Few Shots from Database |
107 | | - logger.info('Sending request...') |
108 | | - # fix db url is Empty |
109 | | - if database_profile['db_url'] == '': |
110 | | - conn_name = database_profile['conn_name'] |
111 | | - db_url = ConnectionManagement.get_db_url_by_name(conn_name) |
112 | | - database_profile['db_url'] = db_url |
113 | | - |
114 | | - response = text_to_sql(database_profile['tables_info'], |
115 | | - database_profile['hints'], |
116 | | - database_profile['prompt_map'], |
117 | | - question.keywords, |
118 | | - model_id=question.bedrock_model_id, |
119 | | - sql_examples=current_nlq_chain.get_retrieve_samples(), |
120 | | - ner_example=entity_slot_retrieve, |
121 | | - dialect=get_db_url_dialect(database_profile['db_url']), |
122 | | - model_provider=None, |
123 | | - with_response_stream=with_response_stream, ) |
124 | | - return response |
125 | | - |
126 | | - |
127 | 80 | async def ask_websocket(websocket: WebSocket, question: Question): |
128 | 81 | logger.info(question) |
129 | 82 | session_id = question.session_id |
@@ -332,134 +285,6 @@ def user_feedback_downvote(data_profiles: str, user_id: str, session_id: str, qu |
332 | 285 | return False |
333 | 286 |
|
334 | 287 |
|
335 | | -def ask_with_response_stream(question: Question, current_nlq_chain: NLQChain) -> dict: |
336 | | - logger.info('try to get generated sql from LLM') |
337 | | - response = get_result_from_llm(question, current_nlq_chain, True) |
338 | | - logger.info("got llm response") |
339 | | - return response |
340 | | - |
341 | | - |
342 | | -def get_executed_result(current_nlq_chain: NLQChain) -> str: |
343 | | - all_profiles = ProfileManagement.get_all_profiles_with_info() |
344 | | - sql_query_result = current_nlq_chain.get_executed_result_df(all_profiles[current_nlq_chain.profile]) |
345 | | - final_sql_query_result = sql_query_result.to_markdown() |
346 | | - return final_sql_query_result |
347 | | - |
348 | | - |
349 | | -async def normal_text_search_websocket(websocket: WebSocket, session_id: str, search_box, model_type, database_profile, |
350 | | - entity_slot, opensearch_info, selected_profile, use_rag, user_id, |
351 | | - model_provider=None, username=None): |
352 | | - entity_slot_retrieve = [] |
353 | | - retrieve_result = [] |
354 | | - response = "" |
355 | | - sql = "" |
356 | | - search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve, |
357 | | - retrieve_result=retrieve_result, response=response, sql=sql) |
358 | | - try: |
359 | | - if database_profile['db_url'] == '': |
360 | | - conn_name = database_profile['conn_name'] |
361 | | - db_url = ConnectionManagement.get_db_url_by_name(conn_name) |
362 | | - database_profile['db_url'] = db_url |
363 | | - # TODO: db_type already set in profile |
364 | | - database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name) |
365 | | - |
366 | | - if len(entity_slot) > 0 and use_rag: |
367 | | - await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "start", |
368 | | - user_id) |
369 | | - for each_entity in entity_slot: |
370 | | - entity_retrieve = get_retrieve_opensearch(opensearch_info, each_entity, "ner", |
371 | | - selected_profile, 1, 0.7) |
372 | | - if len(entity_retrieve) > 0: |
373 | | - entity_slot_retrieve.extend(entity_retrieve) |
374 | | - await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "end", user_id) |
375 | | - |
376 | | - if use_rag: |
377 | | - await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "start", user_id) |
378 | | - retrieve_result = get_retrieve_opensearch(opensearch_info, search_box, "query", |
379 | | - selected_profile, 3, 0.5) |
380 | | - await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "end", user_id) |
381 | | - |
382 | | - await response_websocket(websocket, session_id, "Generating SQL", ContentEnum.STATE, "start", user_id) |
383 | | - |
384 | | - response = text_to_sql(database_profile['tables_info'], |
385 | | - database_profile['hints'], |
386 | | - database_profile['prompt_map'], |
387 | | - search_box, |
388 | | - model_id=model_type, |
389 | | - sql_examples=retrieve_result, |
390 | | - ner_example=entity_slot_retrieve, |
391 | | - dialect=database_profile['db_type'], |
392 | | - model_provider=model_provider) |
393 | | - logger.info(f'{response=}') |
394 | | - await response_websocket(websocket, session_id, "Generating SQL", ContentEnum.STATE, "end", user_id) |
395 | | - sql = get_generated_sql(response) |
396 | | - # post-processing the sql for row level security |
397 | | - post_sql = DataSourceFactory.apply_row_level_security_for_sql( |
398 | | - database_profile['db_type'], |
399 | | - sql, |
400 | | - database_profile['row_level_security_config'], |
401 | | - username |
402 | | - ) |
403 | | - |
404 | | - search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve, |
405 | | - retrieve_result=retrieve_result, response=response, sql="") |
406 | | - search_result.entity_slot_retrieve = entity_slot_retrieve |
407 | | - search_result.retrieve_result = retrieve_result |
408 | | - search_result.response = response |
409 | | - search_result.sql = post_sql |
410 | | - search_result.original_sql = sql |
411 | | - except Exception as e: |
412 | | - logger.exception(e) |
413 | | - return search_result |
414 | | - |
415 | | - |
416 | | -async def normal_sql_regenerating_websocket(websocket: WebSocket, session_id: str, search_box, model_type, |
417 | | - database_profile, entity_slot_retrieve, retrieve_result, additional_info, |
418 | | - username: str): |
419 | | - entity_slot_retrieve = entity_slot_retrieve |
420 | | - retrieve_result = retrieve_result |
421 | | - response = "" |
422 | | - sql = "" |
423 | | - search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve, |
424 | | - retrieve_result=retrieve_result, response=response, sql=sql) |
425 | | - try: |
426 | | - if database_profile['db_url'] == '': |
427 | | - conn_name = database_profile['conn_name'] |
428 | | - db_url = ConnectionManagement.get_db_url_by_name(conn_name) |
429 | | - database_profile['db_url'] = db_url |
430 | | - database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name) |
431 | | - |
432 | | - response = text_to_sql(database_profile['tables_info'], |
433 | | - database_profile['hints'], |
434 | | - database_profile['prompt_map'], |
435 | | - search_box, |
436 | | - model_id=model_type, |
437 | | - sql_examples=retrieve_result, |
438 | | - ner_example=entity_slot_retrieve, |
439 | | - dialect=database_profile['db_type'], |
440 | | - model_provider=None, |
441 | | - additional_info=additional_info) |
442 | | - logger.info("normal_sql_regenerating_websocket") |
443 | | - logger.info(f'{response=}') |
444 | | - sql = get_generated_sql(response) |
445 | | - post_sql = DataSourceFactory.apply_row_level_security_for_sql( |
446 | | - database_profile['db_type'], |
447 | | - sql, |
448 | | - database_profile['row_level_security_config'], |
449 | | - username |
450 | | - ) |
451 | | - search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve, |
452 | | - retrieve_result=retrieve_result, response=response, sql="") |
453 | | - search_result.entity_slot_retrieve = entity_slot_retrieve |
454 | | - search_result.retrieve_result = retrieve_result |
455 | | - search_result.response = response |
456 | | - search_result.sql = post_sql |
457 | | - search_result.original_sql = sql |
458 | | - except Exception as e: |
459 | | - logger.error(e) |
460 | | - return search_result |
461 | | - |
462 | | - |
463 | 288 | async def response_websocket(websocket: WebSocket, session_id: str, content, |
464 | 289 | content_type: ContentEnum = ContentEnum.COMMON, status: str = "-1", |
465 | 290 | user_id: str = "admin"): |
|
0 commit comments