@@ -18,7 +18,7 @@
from comps.agent.src.integrations.agent import instantiate_agent
from comps.agent.src.integrations.global_var import assistants_global_kv, threads_global_kv
from comps.agent.src.integrations.thread import instantiate_thread_memory, thread_completion_callback
-from comps.agent.src.integrations.utils import assemble_store_messages, get_args
+from comps.agent.src.integrations.utils import assemble_store_messages, get_args, get_latest_human_message_from_store
from comps.cores.proto.api_protocol import (
    AssistantsObject,
    ChatCompletionRequest,
@@ -40,7 +40,7 @@

logger.info("========initiating agent============")
logger.info(f"args: {args}")
-agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)
+agent_inst = instantiate_agent(args)


class AgentCompletionRequest(ChatCompletionRequest):
@@ -76,7 +76,7 @@ async def llm_generate(input: AgentCompletionRequest):
    if isinstance(input.messages, str):
        messages = input.messages
    else:
-        # TODO: need handle multi-turn messages
+        # last user message
        messages = input.messages[-1]["content"]

    # 2. prepare the input for the agent
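For context on the replaced comment: with `with_memory` now defaulting to on, the handler forwards only the newest user message and relies on the agent's checkpointer memory for earlier turns. A minimal sketch of what this extraction does (illustrative payload, not from the source):

```python
# Illustrative OpenAI-style messages payload (hypothetical values).
messages_payload = [
    {"role": "user", "content": "Book a flight to SFO."},
    {"role": "assistant", "content": "For which date?"},
    {"role": "user", "content": "Next Friday."},
]
# llm_generate forwards only the last user turn; prior turns are expected
# to be replayed from the checkpointer keyed by the session/thread id.
messages = messages_payload[-1]["content"]  # "Next Friday."
```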
@@ -90,7 +90,6 @@ async def llm_generate(input: AgentCompletionRequest):
    else:
        logger.info("-----------NOT STREAMING-------------")
        response = await agent_inst.non_streaming_run(messages, config)
-        logger.info("-----------Response-------------")
        return GeneratedDoc(text=response, prompt=messages)

@@ -100,14 +99,14 @@ class RedisConfig(BaseModel):

class AgentConfig(BaseModel):
    stream: Optional[bool] = False
-    agent_name: Optional[str] = "OPEA_Default_Agent"
+    agent_name: Optional[str] = "OPEA_Agent"
    strategy: Optional[str] = "react_llama"
-    role_description: Optional[str] = "LLM enhanced agent"
+    role_description: Optional[str] = "AI assistant"
    tools: Optional[str] = None
    recursion_limit: Optional[int] = 5

-    model: Optional[str] = "meta-llama/Meta-Llama-3-8B-Instruct"
-    llm_engine: Optional[str] = None
+    model: Optional[str] = "meta-llama/Llama-3.3-70B-Instruct"
+    llm_engine: Optional[str] = "vllm"
    llm_endpoint_url: Optional[str] = None
    max_new_tokens: Optional[int] = 1024
    top_k: Optional[int] = 10
@@ -117,10 +116,14 @@ class AgentConfig(BaseModel):
    return_full_text: Optional[bool] = False
    custom_prompt: Optional[str] = None

-    # short/long term memory
-    with_memory: Optional[bool] = False
-    # persistence
-    with_store: Optional[bool] = False
+    # short/long term memory
+    with_memory: Optional[bool] = True
+    # agent memory config
+    # chat_completion api: only supports checkpointer memory
+    # assistants api: supports checkpointer and store memory
+    # checkpointer: in-memory checkpointer - MemorySaver()
+    # store: redis store
+    memory_type: Optional[str] = "checkpointer"  # choices: checkpointer, store
    store_config: Optional[RedisConfig] = None

    timeout: Optional[int] = 60
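A minimal sketch of how the two new memory modes combine in this config (illustrative values; RedisConfig's fields are not shown in this diff, so the default instance below is an assumption):

```python
# Checkpointer mode (default): in-memory MemorySaver, works for both
# the chat_completion and assistants APIs.
checkpointer_cfg = AgentConfig(
    strategy="react_llama",
    llm_engine="vllm",
    llm_endpoint_url="http://localhost:8080",  # assumption: any OpenAI-compatible endpoint
    memory_type="checkpointer",
)

# Store mode: Redis-backed persistence, assistants API only.
store_cfg = AgentConfig(
    strategy="react_llama",
    llm_engine="vllm",
    llm_endpoint_url="http://localhost:8080",
    memory_type="store",
    store_config=RedisConfig(),  # assumption: defaults point at a local Redis
)
```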
@@ -147,18 +150,17 @@ class CreateAssistant(CreateAssistantsRequest):
    )
def create_assistants(input: CreateAssistant):
    # 1. initialize the agent
-    agent_inst = instantiate_agent(
-        input.agent_config, input.agent_config.strategy, with_memory=input.agent_config.with_memory
-    )
+    print("@@@ Initializing agent with config: ", input.agent_config)
+    agent_inst = instantiate_agent(input.agent_config)
    assistant_id = agent_inst.id
    created_at = int(datetime.now().timestamp())
    with assistants_global_kv as g_assistants:
        g_assistants[assistant_id] = (agent_inst, created_at)
    logger.info(f"Record assistant inst {assistant_id} in global KV")

-    if input.agent_config.with_store:
+    if input.agent_config.memory_type == "store":
        logger.info("Save Agent Config to database")
-        agent_inst.with_store = input.agent_config.with_store
+        # agent_inst.memory_type = input.agent_config.memory_type
        print(input)
        global db_client
        if db_client is None:
@@ -172,6 +174,7 @@ def create_assistants(input: CreateAssistant):
    return AssistantsObject(
        id=assistant_id,
        created_at=created_at,
+        model=input.agent_config.model,
    )

@@ -211,7 +214,7 @@ def create_messages(thread_id, input: CreateMessagesRequest):
    if isinstance(input.content, str):
        query = input.content
    else:
-        query = input.content[-1]["text"]
+        query = input.content[-1]["text"]  # content is a list of MessageContent
    msg_id, created_at = thread_inst.add_query(query)

    structured_content = MessageContent(text=query)
@@ -224,15 +227,18 @@ def create_messages(thread_id, input: CreateMessagesRequest):
        assistant_id=input.assistant_id,
    )

-    # save messages using assistant_id as key
+    # save messages using assistant_id_thread_id as key
    if input.assistant_id is not None:
        with assistants_global_kv as g_assistants:
            agent_inst, _ = g_assistants[input.assistant_id]
-        if agent_inst.with_store:
-            logger.info(f"Save Agent Messages, assistant_id: {input.assistant_id}, thread_id: {thread_id}")
+        if agent_inst.memory_type == "store":
+            logger.info(f"Save Messages, assistant_id: {input.assistant_id}, thread_id: {thread_id}")
            # if with store, db_client initialized already
            global db_client
-            db_client.put(msg_id, message.model_dump_json(), input.assistant_id)
+            namespace = f"{input.assistant_id}_{thread_id}"
+            # put(key: str, val: dict, collection: str = DEFAULT_COLLECTION)
+            db_client.put(msg_id, message.model_dump_json(), namespace)
+            logger.info(f"@@@ Save message to db: {msg_id}, {message.model_dump_json()}, {namespace}")

    return message

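The put/get contract above keys each message collection by `assistant_id_thread_id`. A dict-backed stand-in (hypothetical, for illustration only; the real client is the Redis-backed store configured via `store_config`) shows the effect:

```python
class FakeDBClient:
    """Hypothetical in-memory stand-in mimicking the put/get_all calls used here."""

    def __init__(self):
        self.collections = {}

    def put(self, key, val, collection):
        self.collections.setdefault(collection, {})[key] = val

    def get_all(self, collection):
        return self.collections.get(collection, {})


db = FakeDBClient()
namespace = "asst_123_thread_456"  # f"{assistant_id}_{thread_id}"
db.put("msg_1", '{"role": "user", "content": "hi"}', namespace)
# Each conversation lives in its own collection, so a run touches only
# the history of its own thread:
print(db.get_all(namespace))  # {'msg_1': '{"role": "user", "content": "hi"}'}
```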
@@ -254,15 +260,24 @@ def create_run(thread_id, input: CreateRunResponse):
    with assistants_global_kv as g_assistants:
        agent_inst, _ = g_assistants[assistant_id]

-    config = {"recursion_limit": args.recursion_limit}
+    config = {
+        "recursion_limit": args.recursion_limit,
+        "configurable": {"session_id": thread_id, "thread_id": thread_id, "user_id": assistant_id},
+    }

-    if agent_inst.with_store:
-        # assemble multi-turn messages
+    if agent_inst.memory_type == "store":
        global db_client
-        input_query = assemble_store_messages(db_client.get_all(assistant_id))
+        namespace = f"{assistant_id}_{thread_id}"
+        # get the latest human message from store in the namespace
+        input_query = get_latest_human_message_from_store(db_client, namespace)
+        print("@@@@ Input_query from store: ", input_query)
    else:
        input_query = thread_inst.get_query()
+        print("@@@@ Input_query from thread_inst: ", input_query)

+    print("@@@ Agent instance:")
+    print(agent_inst.id)
+    print(agent_inst.args)
    try:
        return StreamingResponse(
            thread_completion_callback(agent_inst.stream_generator(input_query, config, thread_id), thread_id),
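The `configurable` block follows LangGraph's convention for checkpointer-scoped state (an assumption here, suggested by the `MemorySaver()` note in AgentConfig). A self-contained sketch of how a thread-scoped checkpoint behaves under that assumption:

```python
from typing import TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    count: int


def step(state: State) -> State:
    return {"count": state["count"] + 1}


builder = StateGraph(State)
builder.add_node("step", step)
builder.add_edge(START, "step")
builder.add_edge("step", END)
graph = builder.compile(checkpointer=MemorySaver())

# Same shape as the config built in create_run: recursion_limit plus a
# "configurable" dict carrying the thread id.
config = {"recursion_limit": 5, "configurable": {"thread_id": "thread_456"}}
print(graph.invoke({"count": 0}, config))  # {'count': 1}
print(graph.get_state(config).values)      # state checkpointed under thread_456
```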