
add code for StarRocks #168

Merged · 3 commits · Jul 11, 2024
15 changes: 14 additions & 1 deletion application/Dockerfile
@@ -7,7 +7,20 @@ RUN adduser --disabled-password --gecos '' appuser
WORKDIR /app

COPY requirements.txt /app/
RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

ARG AWS_REGION=us-east-1
ENV AWS_REGION=${AWS_REGION}

# Print the AWS_REGION for verification
RUN echo "Current AWS Region: $AWS_REGION"

# Install dependencies using the appropriate PyPI source based on AWS region
RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \
pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple; \
else \
pip3 install -r requirements.txt; \
fi


COPY . /app/

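The region-aware install above can also be reproduced outside Docker. A minimal Python sketch, assuming AWS_REGION is exported in the environment and requirements.txt sits in the working directory (this script is not part of the PR):

import os
import subprocess
import sys

# Mirror the Dockerfile logic: use the Tsinghua PyPI mirror only in the AWS China regions.
region = os.environ.get("AWS_REGION", "us-east-1")
cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
if region in ("cn-north-1", "cn-northwest-1"):
    cmd += ["-i", "https://pypi.tuna.tsinghua.edu.cn/simple"]
subprocess.run(cmd, check=True)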
14 changes: 13 additions & 1 deletion application/Dockerfile-api
@@ -3,7 +3,19 @@ FROM public.ecr.aws/docker/library/python:3.10-slim
WORKDIR /app

COPY . /app/
RUN pip3 install -r requirements-api.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

ARG AWS_REGION=us-east-1
ENV AWS_REGION=${AWS_REGION}

# Print the AWS_REGION for verification
RUN echo "Current AWS Region: $AWS_REGION"

# Install dependencies using the appropriate PyPI source based on AWS region
RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \
pip3 install -r requirements-api.txt -i https://pypi.tuna.tsinghua.edu.cn/simple; \
else \
pip3 install -r requirements-api.txt; \
fi

EXPOSE 8000

6 changes: 4 additions & 2 deletions application/api/main.py
@@ -41,12 +41,14 @@ def ask(question: Question):
@router.post("/user_feedback")
def user_feedback(input_data: FeedBackInput):
feedback_type = input_data.feedback_type
user_id = input_data.user_id
session_id = input_data.session_id
if feedback_type == "upvote":
upvote_res = service.user_feedback_upvote(input_data.data_profiles, input_data.query,
upvote_res = service.user_feedback_upvote(input_data.data_profiles, user_id, session_id, input_data.query,
input_data.query_intent, input_data.query_answer)
return upvote_res
else:
downvote_res = service.user_feedback_downvote(input_data.data_profiles, input_data.query,
downvote_res = service.user_feedback_downvote(input_data.data_profiles, user_id, session_id, input_data.query,
input_data.query_intent, input_data.query_answer)
return downvote_res

2 changes: 2 additions & 0 deletions application/api/schemas.py
@@ -39,6 +39,8 @@ class FeedBackInput(BaseModel):
query: str
query_intent: str
query_answer: str
session_id: str = "-1"
user_id: str = "admin"


class Option(BaseModel):
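With session_id and user_id defaulting to "-1" and "admin", existing feedback clients keep working while newer ones can attribute feedback to a user and session. A minimal sketch of an upvote request, assuming the API container from Dockerfile-api is reachable at http://localhost:8000 and the router is mounted at the root path (both are assumptions, not shown in this diff); the profile name, query, and SQL are made-up values:

import requests

payload = {
    "feedback_type": "upvote",
    "data_profiles": "demo_profile",   # hypothetical data profile name
    "query": "top 10 products by sales",
    "query_intent": "normal_search",
    "query_answer": "SELECT product, SUM(sales) FROM orders GROUP BY product ORDER BY 2 DESC LIMIT 10",
    "user_id": "alice",                # new field, defaults to "admin"
    "session_id": "session-001",       # new field, defaults to "-1"
}
resp = requests.post("http://localhost:8000/user_feedback", json=payload)
print(resp.json())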
31 changes: 18 additions & 13 deletions application/api/service.py
@@ -29,10 +29,10 @@
logger = logging.getLogger(__name__)

load_dotenv()
all_profiles = ProfileManagement.get_all_profiles_with_info()


def get_option() -> Option:
all_profiles = ProfileManagement.get_all_profiles_with_info()
option = Option(
data_profiles=all_profiles.keys(),
bedrock_model_ids=BEDROCK_MODEL_IDS,
@@ -62,6 +62,7 @@ def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_re
logger.info('try to get generated sql from LLM')

entity_slot_retrieve = []
all_profiles = ProfileManagement.get_all_profiles_with_info()
database_profile = all_profiles[question.profile_name]
if question.intent_ner_recognition:
intent_response = get_query_intent(question.bedrock_model_id, question.keywords, database_profile['prompt_map'])
@@ -111,6 +112,8 @@ def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_re
def ask(question: Question) -> Answer:
logger.debug(question)
verify_parameters(question)
user_id = question.user_id
session_id = question.session_id

intent_ner_recognition_flag = question.intent_ner_recognition_flag
agent_cot_flag = question.agent_cot_flag
@@ -193,7 +196,7 @@ def ask(question: Question) -> Answer:
answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result,
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
suggested_question=[])
LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql="", query=search_box,
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box,
intent="reject_search", log_info="", time_str=current_time)
return answer
elif search_intent_flag:
@@ -210,7 +213,7 @@ def ask(question: Question) -> Answer:
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
suggested_question=[])

LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql="", query=search_box,
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box,
intent="knowledge_search",
log_info=knowledge_search_result.knowledge_response,
time_str=current_time)
@@ -272,7 +275,7 @@ def ask(question: Question) -> Answer:
sql_search_result.data_show_type = model_select_type

log_info = str(search_intent_result["error_info"]) + ";" + sql_search_result.data_analyse
LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql=sql_search_result.sql,
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql=sql_search_result.sql,
query=search_box,
intent="normal_search",
log_info=log_info,
@@ -318,7 +321,7 @@ def ask(question: Question) -> Answer:
else:
log_info = agent_search_result[i]["query"] + "The SQL error Info: "
log_id = generate_log_id()
LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql=each_task_res["sql"],
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql=each_task_res["sql"],
query=search_box + "; The sub task is " + agent_search_result[i]["query"],
intent="agent_search",
log_info=log_info,
@@ -340,6 +343,7 @@ def ask(question: Question) -> Answer:
async def ask_websocket(websocket: WebSocket, question : Question):
logger.info(question)
session_id = question.session_id
user_id = question.user_id

intent_ner_recognition_flag = question.intent_ner_recognition_flag
agent_cot_flag = question.agent_cot_flag
@@ -424,7 +428,7 @@ async def ask_websocket(websocket: WebSocket, question : Question):
answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result,
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
suggested_question=[])
LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql="", query=search_box,
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box,
intent="reject_search", log_info="", time_str=current_time)
return answer
elif search_intent_flag:
@@ -441,7 +445,7 @@ async def ask_websocket(websocket: WebSocket, question : Question):
sql_search_result=sql_search_result, agent_search_result=agent_search_response,
suggested_question=[])

LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql="", query=search_box,
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box,
intent="knowledge_search",
log_info=knowledge_search_result.knowledge_response,
time_str=current_time)
@@ -511,7 +515,7 @@ async def ask_websocket(websocket: WebSocket, question : Question):
sql_search_result.data_show_type = model_select_type

log_info = str(search_intent_result["error_info"]) + ";" + sql_search_result.data_analyse
LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql=sql_search_result.sql,
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql=sql_search_result.sql,
query=search_box,
intent="normal_search",
log_info=log_info,
@@ -557,7 +561,7 @@ async def ask_websocket(websocket: WebSocket, question : Question):
else:
log_info = agent_search_result[i]["query"] + "The SQL error Info: "
log_id = generate_log_id()
LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql=each_task_res["sql"],
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql=each_task_res["sql"],
query=search_box + "; The sub task is " + agent_search_result[i]["query"],
intent="agent_search",
log_info=log_info,
@@ -576,7 +580,7 @@ async def ask_websocket(websocket: WebSocket, question : Question):
return answer


def user_feedback_upvote(data_profiles: str, query: str, query_intent: str, query_answer):
def user_feedback_upvote(data_profiles: str, user_id: str, session_id: str, query: str, query_intent: str, query_answer):
try:
if query_intent == "normal_search":
VectorStore.add_sample(data_profiles, query, query_answer)
@@ -588,20 +592,20 @@ def user_feedback_downvote(data_profiles: str, query: str, query_intent: str, quer
return False


def user_feedback_downvote(data_profiles: str, query: str, query_intent: str, query_answer):
def user_feedback_downvote(data_profiles: str, user_id: str, session_id: str, query: str, query_intent: str, query_answer):
try:
if query_intent == "normal_search":
log_id = generate_log_id()
current_time = get_current_time()
LogManagement.add_log_to_database(log_id=log_id, profile_name=data_profiles,
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=data_profiles,
sql=query_answer, query=query,
intent="normal_search_user_downvote",
log_info="",
time_str=current_time)
elif query_intent == "agent_search":
log_id = generate_log_id()
current_time = get_current_time()
LogManagement.add_log_to_database(log_id=log_id, profile_name=data_profiles,
LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=data_profiles,
sql=query_answer, query=query,
intent="agent_search_user_downvote",
log_info="",
@@ -624,6 +628,7 @@ def explain_with_response_stream(current_nlq_chain: NLQChain) -> dict:


def get_executed_result(current_nlq_chain: NLQChain) -> str:
all_profiles = ProfileManagement.get_all_profiles_with_info()
sql_query_result = current_nlq_chain.get_executed_result_df(all_profiles[current_nlq_chain.profile])
final_sql_query_result = sql_query_result.to_markdown()
return final_sql_query_result
4 changes: 2 additions & 2 deletions application/nlq/business/log_store.py
@@ -9,5 +9,5 @@ class LogManagement:
query_log_dao = DynamoQueryLogDao()

@classmethod
def add_log_to_database(cls, log_id, profile_name, sql, query, intent, log_info, time_str):
cls.query_log_dao.add_log(log_id, profile_name, sql, query, intent, log_info, time_str)
def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str):
cls.query_log_dao.add_log(log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str)
12 changes: 8 additions & 4 deletions application/nlq/data_access/dynamo_query_log.py
@@ -12,9 +12,11 @@


class DynamoQueryLogEntity:
def __init__(self, log_id, profile_name, sql, query, intent, log_info, time_str):
def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str):
self.log_id = log_id
self.profile_name = profile_name
self.user_id = user_id
self.session_id = session_id
self.sql = sql
self.query = query
self.intent = intent
@@ -26,6 +28,8 @@ def to_dict(self):
return {
'log_id': self.log_id,
'profile_name': self.profile_name,
'user_id': self.user_id,
'session_id': self.session_id,
'sql': self.sql,
'query': self.query,
'intent': self.intent,
@@ -104,11 +108,11 @@ def add(self, entity):
try:
self.table.put_item(Item=entity.to_dict())
except Exception as e:
logger.error("add log entity is error {}",e)
logger.error("add log entity is error {}", e)

def update(self, entity):
self.table.put_item(Item=entity.to_dict())

def add_log(self, log_id, profile_name, sql, query, intent, log_info, time_str):
entity = DynamoQueryLogEntity(log_id, profile_name, sql, query, intent, log_info, time_str)
def add_log(self, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str):
entity = DynamoQueryLogEntity(log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str)
self.add(entity)
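For illustration, a small sketch of how the extended entity serializes with the new attribution fields; the import path assumes the application package layout, the values are made up, and the DynamoDB table wiring is omitted:

from nlq.data_access.dynamo_query_log import DynamoQueryLogEntity  # assumed import path

# Constructor order is (log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str).
entity = DynamoQueryLogEntity(
    "log-0001", "demo_profile", "alice", "session-001",
    "SELECT 1", "show one row", "normal_search", "", "2024-07-11 00:00:00",
)
print(entity.to_dict())  # the item now carries 'user_id' and 'session_id' alongside the existing keys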
3 changes: 2 additions & 1 deletion application/requirements-api.txt
@@ -14,4 +14,5 @@ langchain~=0.1.11
langchain-core~=0.1.30
sqlparse~=0.4.2
pandas==2.0.3
openpyxl
openpyxl
starrocks==1.0.6
2 changes: 1 addition & 1 deletion application/requirements.txt
@@ -14,4 +14,4 @@ sqlparse~=0.4.2
debugpy
pandas==2.0.3
openpyxl
starrocks
starrocks==1.0.6
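Pinning starrocks==1.0.6 in both requirements files brings in the StarRocks SQLAlchemy dialect at a fixed version. A hedged sketch of opening a connection with it; the starrocks:// URL form, the query port 9030, and the credentials, database, and table names are assumptions based on typical usage of the package rather than anything defined in this PR, and SQLAlchemy itself is assumed to be installed:

from sqlalchemy import create_engine, text

# Assumed connection URL for the starrocks dialect; adjust user, password, host, port, and database.
engine = create_engine("starrocks://user:password@127.0.0.1:9030/demo_db")

with engine.connect() as conn:
    # Explicit columns and a LIMIT, in line with the StarRocks prompt added below.
    rows = conn.execute(text("SELECT order_id, amount FROM orders LIMIT 10")).fetchall()
    print(rows)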
5 changes: 5 additions & 0 deletions application/utils/prompt.py
@@ -17,6 +17,11 @@
Pay attention to use CURDATE() function to get the current date, if the question involves "today". In the process of generating SQL statements, please do not use aliases. Aside from giving the SQL answer, concisely explain yourself after giving the answer
in the same language as the question.""".format(top_k=TOP_K)

STARROCKS_DIALECT_PROMPT_CLAUDE3 = """
You are a data analysis expert and proficient in StarRocks. Given an input question, first create a syntactically correct StarRocks SQL query to run, then look at the results of the query and return the answer to the input
question. When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per StarRocks SQL.
Never query for all columns from a table.""".format(top_k=TOP_K)


AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 = """You are a Amazon Redshift expert. Given an input question, first create a syntactically correct Redshift query to run, then look at the results of the query and return the answer to the input
question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL.
4 changes: 3 additions & 1 deletion application/utils/prompts/generate_prompt.py
@@ -1,5 +1,5 @@
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3
from utils.prompts import guidance_prompt
from utils.prompts import table_prompt
import logging
@@ -1907,6 +1907,8 @@ def generate_llm_prompt(ddl, hints, prompt_map, search_box, sql_examples=None, n
dialect_prompt = MYSQL_DIALECT_PROMPT_CLAUDE3
elif dialect == 'redshift':
dialect_prompt = AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
elif dialect == 'starrocks':
dialect_prompt = STARROCKS_DIALECT_PROMPT_CLAUDE3
else:
dialect_prompt = DEFAULT_DIALECT_PROMPT

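The dialect-to-prompt selection now handles StarRocks alongside the existing engines. Read as a lookup table it looks like the sketch below; the 'redshift' and 'starrocks' keys come from the hunk above, while 'mysql' and 'postgresql' stand in for the elided branches and may not match the repository's exact strings:

from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
    AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3, DEFAULT_DIALECT_PROMPT

# Illustrative restatement of the if/elif chain in generate_llm_prompt.
DIALECT_PROMPTS = {
    "postgresql": POSTGRES_DIALECT_PROMPT_CLAUDE3,   # assumed key for an elided branch
    "mysql": MYSQL_DIALECT_PROMPT_CLAUDE3,           # assumed key for an elided branch
    "redshift": AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3,
    "starrocks": STARROCKS_DIALECT_PROMPT_CLAUDE3,
}

dialect = "starrocks"                                # example input
dialect_prompt = DIALECT_PROMPTS.get(dialect, DEFAULT_DIALECT_PROMPT)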