Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions application/.env.cntemplate
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ BEDROCK_REGION=cn-north-1
RDS_REGION_NAME=cn-north-1
AWS_DEFAULT_REGION=cn-north-1

DYNAMODB_AWS_REGION=cn-north-1

SAGEMAKER_ENDPOINT_EMBEDDING=embedding-bge-m3-3ab71
SAGEMAKER_ENDPOINT_INTENT=llm-internlm2-chat-7b-3ab71
SAGEMAKER_ENDPOINT_SQL=sql-sqlcoder-7b-2-7e5b6
Expand Down
11 changes: 9 additions & 2 deletions application/nlq/business/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

logger = logging.getLogger(__name__)


class ConnectionManagement:
connection_config_dao = ConnectConfigDao()

Expand All @@ -23,7 +24,8 @@ def get_conn_config_by_name(cls, conn_name):

@classmethod
def update_connection(cls, conn_name, db_type, db_host, db_port, db_user, db_pwd, db_name, comment):
cls.connection_config_dao.update_db_info(conn_name, db_type, db_host, db_port, db_user, db_pwd, db_name, comment)
cls.connection_config_dao.update_db_info(conn_name, db_type, db_host, db_port, db_user, db_pwd, db_name,
comment)
logger.info(f"Connection {conn_name} updated")

@classmethod
Expand All @@ -48,4 +50,9 @@ def get_table_definition_by_config(cls, conn_config: ConnectConfigEntity, schema
@classmethod
def get_db_url_by_name(cls, conn_name):
conn_config = cls.get_conn_config_by_name(conn_name)
return RelationDatabase.get_db_url_by_connection(conn_config)
return RelationDatabase.get_db_url_by_connection(conn_config)

@classmethod
def get_db_type_by_name(cls, conn_name):
conn_config = cls.get_conn_config_by_name(conn_name)
return conn_config.db_type
2 changes: 1 addition & 1 deletion application/nlq/data_access/dynamo_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# DynamoDB table name
CONNECT_CONFIG_TABLE_NAME = 'NlqConnectConfig'
DYNAMODB_AWS_REGION = os.environ.get('DYNAMODB_AWS_REGION', 'us-west-2')
DYNAMODB_AWS_REGION = os.environ.get('DYNAMODB_AWS_REGION')

class ConnectConfigEntity:
"""Connect config entity mapped to DynamoDB item"""
Expand Down
6 changes: 3 additions & 3 deletions application/pages/1_🌍_Natural_Language_Querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def main():
conn_name = database_profile['conn_name']
db_url = ConnectionManagement.get_db_url_by_name(conn_name)
database_profile['db_url'] = db_url
database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name)

if intent_ner_recognition:
intent_response = get_query_intent(model_type, search_box)
Expand Down Expand Up @@ -384,8 +385,7 @@ def main():
model_id=model_type,
sql_examples=retrieve_result,
ner_example=entity_slot_retrieve,
dialect=get_db_url_dialect(
database_profile['db_url']),
dialect=database_profile['db_type'],
model_provider=model_provider)
sql_str = get_response_sql(each_task_sql_query)
each_res_dict["sql"] = sql_str
Expand Down Expand Up @@ -440,7 +440,7 @@ def main():
model_id=model_type,
sql_examples=retrieve_result,
ner_example=entity_slot_retrieve,
dialect=get_db_url_dialect(database_profile['db_url']),
dialect=database_profile['db_type'],
model_provider=model_provider)

logger.info(f'got llm response: {response}')
Expand Down
6 changes: 2 additions & 4 deletions application/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from utils import opensearch
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
DEFAULT_DIALECT_PROMPT, SEARCH_INTENT_PROMPT_CLAUDE3, CLAUDE3_DATA_ANALYSE_SYSTEM_PROMPT, \
CLAUDE3_DATA_ANALYSE_USER_PROMPT
CLAUDE3_DATA_ANALYSE_USER_PROMPT, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
import os
import logging
from langchain_core.output_parsers import JsonOutputParser
Expand Down Expand Up @@ -182,9 +182,7 @@ def generate_prompt(ddl, hints, search_box, sql_examples=None, ner_example=None,
elif dialect == 'mysql':
dialect_prompt = MYSQL_DIALECT_PROMPT_CLAUDE3
elif dialect == 'redshift':
dialect_prompt = '''You are a Amazon Redshift expert. Given an input question, first create a syntactically
correct Redshift query to run, then look at the results of the query and return the answer to the input
question.'''
dialect_prompt = AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
else:
dialect_prompt = DEFAULT_DIALECT_PROMPT

Expand Down
5 changes: 5 additions & 0 deletions application/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
Pay attention to use CURDATE() function to get the current date, if the question involves "today". Aside from giving the SQL answer, concisely explain yourself after giving the answer
in the same language as the question.""".format(top_k=TOP_K)


AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 = """You are a Amazon Redshift expert. Given an input question, first create a syntactically correct Redshift query to run, then look at the results of the query and return the answer to the input
question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL.
Never query for all columns from a table.""".format(top_k=TOP_K)

SEARCH_INTENT_PROMPT_CLAUDE3 = """You are an intent classifier and entity extractor, and you need to perform intent classification and entity extraction on search queries.
Background: I want to query data in the database, and you need to help me determine the user's relevant intent and extract the keywords from the query statement. Finally, return a JSON structure.

Expand Down
6 changes: 2 additions & 4 deletions application/utils/prompts/generate_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
DEFAULT_DIALECT_PROMPT, AGENT_COT_SYSTEM_PROMPT, AGENT_COT_EXAMPLE
DEFAULT_DIALECT_PROMPT, AGENT_COT_SYSTEM_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
from utils.prompts import guidance_prompt
from utils.prompts import table_prompt
import logging
Expand Down Expand Up @@ -285,9 +285,7 @@ def generate_llm_prompt(ddl, hints, search_box, sql_examples=None, ner_example=N
elif dialect == 'mysql':
dialect_prompt = MYSQL_DIALECT_PROMPT_CLAUDE3
elif dialect == 'redshift':
dialect_prompt = '''You are a Amazon Redshift expert. Given an input question, first create a syntactically
correct Redshift query to run, then look at the results of the query and return the answer to the input
question. query for at most 100 results using the LIMIT. '''
dialect_prompt = AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
else:
dialect_prompt = DEFAULT_DIALECT_PROMPT

Expand Down