-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcli_simple_chat.py
More file actions
125 lines (113 loc) · 5.28 KB
/
cli_simple_chat.py
File metadata and controls
125 lines (113 loc) · 5.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from api_generator import load_config, initialize_client, generate_api_response
from prompt.chat_prompt import ChatContext
from prompt.character_prompt import get_available_characters, get_character_info
from typing import Optional
class ChatSession:
def __init__(self, character_type: str, db_manager=None, session_id: Optional[int]=None):
self.character_type = character_type
self.db_manager = db_manager
self.session_id = session_id
# 使用 basic 详细程度创建聊天上下文
self.chat_context = ChatContext(character_type, "basic")
self.character_info = get_character_info(character_type)
if self.session_id and self.db_manager:
# 尝试加载持久化的长对话上下文
context = self.db_manager.load_context(self.session_id)
if context:
self.chat_context.dialogue_context.context = context
else:
if self.db_manager is None:
from database.db_manager import DatabaseManager
self.db_manager = DatabaseManager()
self.session_id = self.db_manager.create_session(
self.character_type,
self.character_info.get("name", "未知"),
self.character_info.get("description", "")
)
def add_message(self, role: str, content: str):
self.chat_context.add_message(role, content)
if self.db_manager and self.session_id:
self.db_manager.add_message(self.session_id, role, content)
def get_character_name(self) -> str:
return self.character_info.get("name", "未知")
def get_character_description(self) -> str:
return self.character_info.get("description", "")
def main():
# 加载配置并初始化 OpenAI 客户端
config = load_config()
client = initialize_client(config)
if client is None:
print("无法初始化OpenAI客户端。请检查配置。")
return
# 显示菜单选项:加载模型或开启新对话
print("请选择操作:")
print("1. 加载模型(加载历史对话)")
print("2. 开启新对话")
choice = input("请输入操作编号:").strip()
if choice == "1":
# 加载历史对话:从数据库中加载已存储的长对话会话
from database.db_manager import DatabaseManager
db_manager = DatabaseManager()
sessions = db_manager.get_available_sessions()
if not sessions:
print("没有找到历史对话,切换到新对话模式。")
choice = "2"
else:
print("可用历史对话会话:")
for idx, session in enumerate(sessions, start=1):
# 显示会话ID、角色名称和开始时间
print(f"{idx}. 会话ID: {session['session_id']} 角色: {session['name']} 开始时间: {session['start_time']}")
sess_choice = input("请选择会话编号:").strip()
try:
sess_index = int(sess_choice) - 1
if sess_index < 0 or sess_index >= len(sessions):
raise ValueError
selected_session = sessions[sess_index]
char_type = selected_session["character_type"]
session_id = selected_session["session_id"]
except Exception:
print("无效选择,切换到新对话模式。")
choice = "2"
elif choice == "2":
# 开启新对话:显示可用角色列表供选择(数字选择)
available_characters = get_available_characters()
print("可用角色类型:")
for idx, role in enumerate(available_characters, start=1):
print(f"{idx}. {role}")
role_choice = input("请选择角色编号:").strip()
try:
role_index = int(role_choice) - 1
if role_index < 0 or role_index >= len(available_characters):
raise ValueError
char_type = available_characters[role_index]
session_id = None # 新对话,无历史会话编号
except Exception:
print("无效选择,使用默认角色 'default'")
char_type = "default"
session_id = None
elif choice not in ["1", "2"]:
print("无效选择,退出。")
return
# 创建会话对象,若加载历史对话则传入 db_manager 和 session_id
if choice == "1" and 'session_id' in locals() and session_id is not None:
chat_session = ChatSession(char_type, db_manager, session_id)
else:
chat_session = ChatSession(char_type)
# 显示角色基本信息及欢迎提示
print(f"\n已选择角色: {chat_session.get_character_name()}")
print(f"角色描述: {chat_session.get_character_description()}\n")
print("开始聊天,输入 'exit' 或 'quit' 退出。\n")
while True:
try:
user_input = input("> ")
except KeyboardInterrupt:
print("\n退出...")
break
if user_input.lower() in ["exit", "quit"]:
print("退出...")
break
# 调用 API 接口生成回复,并打印
response = generate_api_response(user_input, chat_session, client, config)
print(f"\n{chat_session.get_character_name()}: {response}\n")
if __name__ == "__main__":
main()