Skip to content
This repository was archived by the owner on Jun 22, 2025. It is now read-only.

Commit 7678bde

Browse files
Improve multi-turn capability for agent (opea-project#1248)
* first code for multi-turn — Signed-off-by: minmin-intel <[email protected]>
* test redis persistence — Signed-off-by: minmin-intel <[email protected]>
* integrate persistent store in react llama — Signed-off-by: minmin-intel <[email protected]>
* test multi-turn — Signed-off-by: minmin-intel <[email protected]>
* multi-turn for assistants api and chatcompletion api — Signed-off-by: minmin-intel <[email protected]>
* update readme and ut script — Signed-off-by: minmin-intel <[email protected]>
* update readme and ut scripts — Signed-off-by: minmin-intel <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix bug — Signed-off-by: minmin-intel <[email protected]>
* change memory type naming — Signed-off-by: minmin-intel <[email protected]>
* fix with_memory as str — Signed-off-by: minmin-intel <[email protected]>

---------

Signed-off-by: minmin-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5476aab commit 7678bde

19 files changed: +813 additions, −335 deletions

comps/agent/src/README.md

Lines changed: 89 additions & 48 deletions
Large diffs are not rendered by default.

comps/agent/src/agent.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from comps.agent.src.integrations.agent import instantiate_agent
1919
from comps.agent.src.integrations.global_var import assistants_global_kv, threads_global_kv
2020
from comps.agent.src.integrations.thread import instantiate_thread_memory, thread_completion_callback
21-
from comps.agent.src.integrations.utils import assemble_store_messages, get_args
21+
from comps.agent.src.integrations.utils import assemble_store_messages, get_args, get_latest_human_message_from_store
2222
from comps.cores.proto.api_protocol import (
2323
AssistantsObject,
2424
ChatCompletionRequest,
@@ -40,7 +40,7 @@
4040

4141
logger.info("========initiating agent============")
4242
logger.info(f"args: {args}")
43-
agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)
43+
agent_inst = instantiate_agent(args)
4444

4545

4646
class AgentCompletionRequest(ChatCompletionRequest):
@@ -76,7 +76,7 @@ async def llm_generate(input: AgentCompletionRequest):
7676
if isinstance(input.messages, str):
7777
messages = input.messages
7878
else:
79-
# TODO: need handle multi-turn messages
79+
# last user message
8080
messages = input.messages[-1]["content"]
8181

8282
# 2. prepare the input for the agent
@@ -90,7 +90,6 @@ async def llm_generate(input: AgentCompletionRequest):
9090
else:
9191
logger.info("-----------NOT STREAMING-------------")
9292
response = await agent_inst.non_streaming_run(messages, config)
93-
logger.info("-----------Response-------------")
9493
return GeneratedDoc(text=response, prompt=messages)
9594

9695

@@ -100,14 +99,14 @@ class RedisConfig(BaseModel):
10099

101100
class AgentConfig(BaseModel):
102101
stream: Optional[bool] = False
103-
agent_name: Optional[str] = "OPEA_Default_Agent"
102+
agent_name: Optional[str] = "OPEA_Agent"
104103
strategy: Optional[str] = "react_llama"
105-
role_description: Optional[str] = "LLM enhanced agent"
104+
role_description: Optional[str] = "AI assistant"
106105
tools: Optional[str] = None
107106
recursion_limit: Optional[int] = 5
108107

109-
model: Optional[str] = "meta-llama/Meta-Llama-3-8B-Instruct"
110-
llm_engine: Optional[str] = None
108+
model: Optional[str] = "meta-llama/Llama-3.3-70B-Instruct"
109+
llm_engine: Optional[str] = "vllm"
111110
llm_endpoint_url: Optional[str] = None
112111
max_new_tokens: Optional[int] = 1024
113112
top_k: Optional[int] = 10
@@ -117,10 +116,14 @@ class AgentConfig(BaseModel):
117116
return_full_text: Optional[bool] = False
118117
custom_prompt: Optional[str] = None
119118

120-
# short/long term memory
121-
with_memory: Optional[bool] = False
122-
# persistence
123-
with_store: Optional[bool] = False
119+
# # short/long term memory
120+
with_memory: Optional[bool] = True
121+
# agent memory config
122+
# chat_completion api: only supports checkpointer memory
123+
# assistants api: supports checkpointer and store memory
124+
# checkpointer: in-memory checkpointer - MemorySaver()
125+
# store: redis store
126+
memory_type: Optional[str] = "checkpointer" # choices: checkpointer, store
124127
store_config: Optional[RedisConfig] = None
125128

126129
timeout: Optional[int] = 60
@@ -147,18 +150,17 @@ class CreateAssistant(CreateAssistantsRequest):
147150
)
148151
def create_assistants(input: CreateAssistant):
149152
# 1. initialize the agent
150-
agent_inst = instantiate_agent(
151-
input.agent_config, input.agent_config.strategy, with_memory=input.agent_config.with_memory
152-
)
153+
print("@@@ Initializing agent with config: ", input.agent_config)
154+
agent_inst = instantiate_agent(input.agent_config)
153155
assistant_id = agent_inst.id
154156
created_at = int(datetime.now().timestamp())
155157
with assistants_global_kv as g_assistants:
156158
g_assistants[assistant_id] = (agent_inst, created_at)
157159
logger.info(f"Record assistant inst {assistant_id} in global KV")
158160

159-
if input.agent_config.with_store:
161+
if input.agent_config.memory_type == "store":
160162
logger.info("Save Agent Config to database")
161-
agent_inst.with_store = input.agent_config.with_store
163+
# agent_inst.memory_type = input.agent_config.memory_type
162164
print(input)
163165
global db_client
164166
if db_client is None:
@@ -172,6 +174,7 @@ def create_assistants(input: CreateAssistant):
172174
return AssistantsObject(
173175
id=assistant_id,
174176
created_at=created_at,
177+
model=input.agent_config.model,
175178
)
176179

177180

@@ -211,7 +214,7 @@ def create_messages(thread_id, input: CreateMessagesRequest):
211214
if isinstance(input.content, str):
212215
query = input.content
213216
else:
214-
query = input.content[-1]["text"]
217+
query = input.content[-1]["text"] # content is a list of MessageContent
215218
msg_id, created_at = thread_inst.add_query(query)
216219

217220
structured_content = MessageContent(text=query)
@@ -224,15 +227,18 @@ def create_messages(thread_id, input: CreateMessagesRequest):
224227
assistant_id=input.assistant_id,
225228
)
226229

227-
# save messages using assistant_id as key
230+
# save messages using assistant_id_thread_id as key
228231
if input.assistant_id is not None:
229232
with assistants_global_kv as g_assistants:
230233
agent_inst, _ = g_assistants[input.assistant_id]
231-
if agent_inst.with_store:
232-
logger.info(f"Save Agent Messages, assistant_id: {input.assistant_id}, thread_id: {thread_id}")
234+
if agent_inst.memory_type == "store":
235+
logger.info(f"Save Messages, assistant_id: {input.assistant_id}, thread_id: {thread_id}")
233236
# if with store, db_client initialized already
234237
global db_client
235-
db_client.put(msg_id, message.model_dump_json(), input.assistant_id)
238+
namespace = f"{input.assistant_id}_{thread_id}"
239+
# put(key: str, val: dict, collection: str = DEFAULT_COLLECTION)
240+
db_client.put(msg_id, message.model_dump_json(), namespace)
241+
logger.info(f"@@@ Save message to db: {msg_id}, {message.model_dump_json()}, {namespace}")
236242

237243
return message
238244

@@ -254,15 +260,24 @@ def create_run(thread_id, input: CreateRunResponse):
254260
with assistants_global_kv as g_assistants:
255261
agent_inst, _ = g_assistants[assistant_id]
256262

257-
config = {"recursion_limit": args.recursion_limit}
263+
config = {
264+
"recursion_limit": args.recursion_limit,
265+
"configurable": {"session_id": thread_id, "thread_id": thread_id, "user_id": assistant_id},
266+
}
258267

259-
if agent_inst.with_store:
260-
# assemble multi-turn messages
268+
if agent_inst.memory_type == "store":
261269
global db_client
262-
input_query = assemble_store_messages(db_client.get_all(assistant_id))
270+
namespace = f"{assistant_id}_{thread_id}"
271+
# get the latest human message from store in the namespace
272+
input_query = get_latest_human_message_from_store(db_client, namespace)
273+
print("@@@@ Input_query from store: ", input_query)
263274
else:
264275
input_query = thread_inst.get_query()
276+
print("@@@@ Input_query from thread_inst: ", input_query)
265277

278+
print("@@@ Agent instance:")
279+
print(agent_inst.id)
280+
print(agent_inst.args)
266281
try:
267282
return StreamingResponse(
268283
thread_completion_callback(agent_inst.stream_generator(input_query, config, thread_id), thread_id),

comps/agent/src/integrations/agent.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
3+
from .storage.persistence_redis import RedisPersistence
34
from .utils import load_python_prompt
45

56

6-
def instantiate_agent(args, strategy="react_langchain", with_memory=False):
7+
def instantiate_agent(args):
8+
strategy = args.strategy
9+
with_memory = args.with_memory
10+
711
if args.custom_prompt is not None:
812
print(f">>>>>> custom_prompt enabled, {args.custom_prompt}")
913
custom_prompt = load_python_prompt(args.custom_prompt)
@@ -22,7 +26,7 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
2226
print("Initializing ReAct Agent with LLAMA")
2327
from .strategy.react import ReActAgentLlama
2428

25-
return ReActAgentLlama(args, with_memory, custom_prompt=custom_prompt)
29+
return ReActAgentLlama(args, custom_prompt=custom_prompt)
2630
elif strategy == "plan_execute":
2731
from .strategy.planexec import PlanExecuteAgentWithLangGraph
2832

comps/agent/src/integrations/strategy/base_agent.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from uuid import uuid4
55

6+
from langgraph.checkpoint.memory import MemorySaver
7+
8+
from ..storage.persistence_redis import RedisPersistence
69
from ..tools import get_tools_descriptions
710
from ..utils import adapt_custom_prompt, setup_chat_model
811

@@ -12,11 +15,25 @@ def __init__(self, args, local_vars=None, **kwargs) -> None:
1215
self.llm = setup_chat_model(args)
1316
self.tools_descriptions = get_tools_descriptions(args.tools)
1417
self.app = None
15-
self.memory = None
1618
self.id = f"assistant_{self.__class__.__name__}_{uuid4()}"
1719
self.args = args
1820
adapt_custom_prompt(local_vars, kwargs.get("custom_prompt"))
19-
print(self.tools_descriptions)
21+
print("Registered tools: ", self.tools_descriptions)
22+
23+
if args.with_memory:
24+
if args.memory_type == "checkpointer":
25+
self.memory_type = "checkpointer"
26+
self.checkpointer = MemorySaver()
27+
self.store = None
28+
elif args.memory_type == "store":
29+
# print("Using Redis as store: ", args.store_config.redis_uri)
30+
self.store = RedisPersistence(args.store_config.redis_uri)
31+
self.memory_type = "store"
32+
else:
33+
raise ValueError("Invalid memory type!")
34+
else:
35+
self.store = None
36+
self.checkpointer = None
2037

2138
@property
2239
def is_vllm(self):
@@ -60,10 +77,7 @@ async def non_streaming_run(self, query, config):
6077
try:
6178
async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
6279
message = s["messages"][-1]
63-
if isinstance(message, tuple):
64-
print(message)
65-
else:
66-
message.pretty_print()
80+
message.pretty_print()
6781

6882
last_message = s["messages"][-1]
6983
print("******Response: ", last_message.content)

0 commit comments

Comments (0)