Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions application/utils/prompts/generate_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
"mistral.mixtral-8x7b-instruct-v0:1": "mixtral-8x7b-instruct-0"
}

system_prompt_dict = {}
user_prompt_dict = {}

system_prompt_dict['mixtral-8x7b-instruct-0'] = """
user_prompt_dict['mixtral-8x7b-instruct-0'] = """
{dialect_prompt}

Assume a database with the following tables and columns exists:
Expand Down Expand Up @@ -55,9 +55,11 @@

Think about your answer first before you respond. Put your sql in <sql></sql> tags.

The question is : {question}

"""

system_prompt_dict['haiku-20240307v1-0'] = """
user_prompt_dict['haiku-20240307v1-0'] = """
{dialect_prompt}

Assume a database with the following tables and columns exists:
Expand Down Expand Up @@ -98,9 +100,11 @@

Think about your answer first before you respond. Put your sql in <sql></sql> tags.

The question is : {question}

"""

system_prompt_dict['sonnet-20240229v1-0'] = """
user_prompt_dict['sonnet-20240229v1-0'] = """
{dialect_prompt}

Assume a database with the following tables and columns exists:
Expand Down Expand Up @@ -141,6 +145,22 @@

Think about your answer first before you respond. Put your sql in <sql></sql> tags.

The question is : {question}

"""

system_prompt_dict = {}

system_prompt_dict['mixtral-8x7b-instruct-0'] = """
You are a data analysis expert and proficient in {dialect}.
"""

system_prompt_dict['haiku-20240307v1-0'] = """
You are a data analysis expert and proficient in {dialect}.
"""

system_prompt_dict['sonnet-20240229v1-0'] = """
You are a data analysis expert and proficient in {dialect}.
"""


Expand All @@ -152,6 +172,14 @@ def get_variable(self, name):
return self.variable_map.get(name)


class UserPromptMapper:
def __init__(self):
self.variable_map = user_prompt_dict

def get_variable(self, name):
return self.variable_map.get(name)


def generate_create_table_ddl(table_description):
lines = table_description.strip().split('\n')
table_name = lines[0].split(':')[0].strip()
Expand Down Expand Up @@ -184,6 +212,7 @@ def generate_create_table_ddl(table_description):


system_prompt_mapper = SystemPromptMapper()
user_prompt_mapper = UserPromptMapper()
table_prompt_mapper = table_prompt.TablePromptMapper()
guidance_prompt_mapper = guidance_prompt.GuidancePromptMapper()

Expand All @@ -209,7 +238,7 @@ def generate_llm_prompt(ddl, hints, search_box, sql_examples=None, ner_example=N
elif dialect == 'redshift':
dialect_prompt = '''You are a Amazon Redshift expert. Given an input question, first create a syntactically
correct Redshift query to run, then look at the results of the query and return the answer to the input
question.'''
question. query for at most 100 results using the LIMIT. '''
else:
dialect_prompt = DEFAULT_DIALECT_PROMPT

Expand All @@ -227,17 +256,21 @@ def generate_llm_prompt(ddl, hints, search_box, sql_examples=None, ner_example=N

name = support_model_ids_map[model_id]
system_prompt = system_prompt_mapper.get_variable(name)
user_prompt = user_prompt_mapper.get_variable(name)
if long_string == '':
table_prompt = table_prompt_mapper.get_variable(name)
else:
table_prompt = long_string
guidance_prompt = guidance_prompt_mapper.get_variable(name)

system_prompt = system_prompt.format(dialect_prompt=dialect_prompt, sql_schema=table_prompt,
sql_guidance=guidance_prompt, examples=example_sql_prompt,
ner_info=example_ner_prompt)
if dialect == "redshift":
system_prompt = system_prompt.format(dialect="Amazon Redshift")
else:
system_prompt = system_prompt.format(dialect=dialect)

user_prompt = search_box
user_prompt = user_prompt.format(dialect_prompt=dialect_prompt, sql_schema=table_prompt,
sql_guidance=guidance_prompt, examples=example_sql_prompt,
ner_info=example_ner_prompt, question=search_box)

return user_prompt, system_prompt

Expand Down