diff --git a/application/nlq/business/profile.py b/application/nlq/business/profile.py index b611703..317fdad 100644 --- a/application/nlq/business/profile.py +++ b/application/nlq/business/profile.py @@ -66,3 +66,8 @@ def update_table_def(cls, profile_name, tables_info, merge_before_update=False): cls.profile_config_dao.update_table_def(profile_name, tables_info) logger.info(f"Table definition updated") + + @classmethod + def update_table_prompt(cls, profile_name, system_prompt, user_prompt): + cls.profile_config_dao.update_table_prompt(profile_name, system_prompt, user_prompt) + logger.info(f"System and user prompt updated") diff --git a/application/nlq/data_access/dynamo_profile.py b/application/nlq/data_access/dynamo_profile.py index ff4d55c..2182595 100644 --- a/application/nlq/data_access/dynamo_profile.py +++ b/application/nlq/data_access/dynamo_profile.py @@ -3,6 +3,7 @@ from typing import List from boto3.dynamodb.conditions import Key, Attr from botocore.exceptions import ClientError +from utils.prompts.generate_prompt import system_prompt_dict, user_prompt_dict logger = logging.getLogger(__name__) @@ -12,13 +13,17 @@ class ProfileConfigEntity: - def __init__(self, profile_name: str, conn_name: str, schemas: List[str], tables: List[str], comments: str, tables_info: dict=None): + def __init__(self, profile_name: str, conn_name: str, schemas: List[str], tables: List[str], comments: str, + tables_info: dict = None, system_prompt: dict = system_prompt_dict, + user_prompt: dict = user_prompt_dict): self.profile_name = profile_name self.conn_name = conn_name self.schemas = schemas self.tables = tables self.comments = comments self.tables_info = tables_info + self.system_prompt = system_prompt + self.user_prompt = user_prompt def to_dict(self): """Convert to DynamoDB item format""" @@ -27,7 +32,9 @@ def to_dict(self): 'profile_name': self.profile_name, 'schemas': self.schemas, 'tables': self.tables, - 'comments': self.comments + 'comments': self.comments, + 'system_prompt': self.system_prompt, + 'user_prompt': self.user_prompt } if self.tables_info: base_props['tables_info'] = self.tables_info @@ -137,4 +144,24 @@ def update_table_def(self, profile_name, tables_info): ) raise else: - return response["Attributes"] \ No newline at end of file + return response["Attributes"] + + def update_table_prompt(self, profile_name, system_prompt, user_prompt): + try: + response = self.table.update_item( + Key={"profile_name": profile_name}, + UpdateExpression="set system_prompt=:sp, user_prompt=:up", + ExpressionAttributeValues={":sp": system_prompt, ":up": user_prompt}, + ReturnValues="UPDATED_NEW", + ) + except ClientError as err: + logger.error( + "Couldn't update profile %s in table %s. Here's why: %s: %s", + profile_name, + self.table.name, + err.response["Error"]["Code"], + err.response["Error"]["Message"], + ) + raise + else: + return response["Attributes"] diff --git "a/application/pages/5_\360\237\252\231_Prompt_Management.py" "b/application/pages/5_\360\237\252\231_Prompt_Management.py" new file mode 100644 index 0000000..bb0fb88 --- /dev/null +++ "b/application/pages/5_\360\237\252\231_Prompt_Management.py" @@ -0,0 +1,45 @@ +import streamlit as st +from dotenv import load_dotenv +import logging +from nlq.business.profile import ProfileManagement +from utils.navigation import make_sidebar + +logger = logging.getLogger(__name__) + + +def main(): + load_dotenv() + logger.info('start prompt management') + st.set_page_config(page_title="Prompt Management") + make_sidebar() + + with st.sidebar: + st.title("Prompt Management") + current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(), + index=None, + placeholder="Please select data profile...", key='current_profile_name') + + if current_profile is not None: + profile_detail = ProfileManagement.get_profile_by_name(current_profile) + + system_prompt = profile_detail.system_prompt + user_prompt = profile_detail.user_prompt + if system_prompt is not None and user_prompt is not None: + model_selected_table = st.selectbox("LLM Model", system_prompt.keys(), index=None, + placeholder="Please select a model") + if model_selected_table is not None: + system_prompt_input = st.text_area('System Prompt', system_prompt[model_selected_table]) + user_prompt_input = st.text_area('User Prompt', user_prompt[model_selected_table], height=500) + + if st.button('Save', type='primary'): + # assign new system/user prompt by selected model + system_prompt[model_selected_table] = system_prompt_input + user_prompt[model_selected_table] = user_prompt_input + + # save new profile to DynamoDB + ProfileManagement.update_table_prompt(current_profile, system_prompt, user_prompt) + st.success('saved.') + + +if __name__ == '__main__': + main() diff --git "a/application/pages/5_\360\237\223\232_Index_Management.py" "b/application/pages/6_\360\237\223\232_Index_Management.py" similarity index 100% rename from "application/pages/5_\360\237\223\232_Index_Management.py" rename to "application/pages/6_\360\237\223\232_Index_Management.py" diff --git "a/application/pages/6_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" similarity index 100% rename from "application/pages/6_\360\237\223\232_Entity_Management.py" rename to "application/pages/7_\360\237\223\232_Entity_Management.py" diff --git "a/application/pages/7_\360\237\223\232_Agent_Cot_Management.py" "b/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" similarity index 100% rename from "application/pages/7_\360\237\223\232_Agent_Cot_Management.py" rename to "application/pages/8_\360\237\223\232_Agent_Cot_Management.py" diff --git "a/application/pages/8_\360\237\226\245_Suggested_Question_Management.py" "b/application/pages/9_\360\237\226\245_Suggested_Question_Management.py" similarity index 100% rename from "application/pages/8_\360\237\226\245_Suggested_Question_Management.py" rename to "application/pages/9_\360\237\226\245_Suggested_Question_Management.py" diff --git a/application/utils/navigation.py b/application/utils/navigation.py index f4e8a01..168adea 100644 --- a/application/utils/navigation.py +++ b/application/utils/navigation.py @@ -39,12 +39,13 @@ def make_sidebar(): st.page_link("pages/2_🪙_Data_Connection_Management.py", label="Data Connection Management", icon="🪙") st.page_link("pages/3_🪙_Data_Profile_Management.py", label="Data Profile Management", icon="🪙") st.page_link("pages/4_🪙_Schema_Description_Management.py", label="Schema Description Management", icon="🪙") + st.page_link("pages/5_🪙_Prompt_Management.py", label="Prompt Management", icon="🪙") st.markdown(":gray[Performance Enhancement]") - st.page_link("pages/5_📚_Index_Management.py", label="Index Management", icon="📚") - st.page_link("pages/6_📚_Entity_Management.py", label="Entity Management", icon="📚") - st.page_link("pages/7_📚_Agent_Cot_Management.py", label="Agent Cot Management", icon="📚") + st.page_link("pages/6_📚_Index_Management.py", label="Index Management", icon="📚") + st.page_link("pages/7_📚_Entity_Management.py", label="Entity Management", icon="📚") + st.page_link("pages/8_📚_Agent_Cot_Management.py", label="Agent Cot Management", icon="📚") st.markdown(":gray[Dashboard Customization Management]") - st.page_link("pages/8_🖥_Suggested_Question_Management.py", label="Suggested Question Management", + st.page_link("pages/9_🖥_Suggested_Question_Management.py", label="Suggested Question Management", icon="🖥") if st.button("Log out"):