Skip to content

Commit 12ea6e0

Browse files
minmin-intel authored and cwlacewe committed
enable custom prompt for react_llama and react_langgraph (opea-project#1391)
Signed-off-by: minmin-intel <[email protected]> Signed-off-by: Lacewell, Chaunte W <[email protected]>
1 parent 046eed0 commit 12ea6e0

File tree

4 files changed

+74
-48
lines changed

4 files changed

+74
-48
lines changed

comps/agent/src/integrations/strategy/react/planner.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from ...storage.persistence_redis import RedisPersistence
1515
from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
1616
from ..base_agent import BaseAgent
17-
from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt
1817

1918

2019
class ReActAgentwithLangchain(BaseAgent):
2120
def __init__(self, args, with_memory=False, **kwargs):
2221
super().__init__(args, local_vars=globals(), **kwargs)
22+
from .prompt import hwchase17_react_prompt
23+
2324
prompt = hwchase17_react_prompt
2425
if has_multi_tool_inputs(self.tools_descriptions):
2526
raise ValueError("Only supports single input tools when using strategy == react_langchain")
@@ -86,7 +87,12 @@ async def stream_generator(self, query, config, thread_id=None):
8687
class ReActAgentwithLanggraph(BaseAgent):
8788
def __init__(self, args, with_memory=False, **kwargs):
8889
super().__init__(args, local_vars=globals(), **kwargs)
89-
90+
if kwargs.get("custom_prompt") is not None:
91+
print("***Custom prompt is provided.")
92+
REACT_SYS_MESSAGE = kwargs.get("custom_prompt").REACT_SYS_MESSAGE
93+
else:
94+
print("*** Using default prompt.")
95+
from .prompt import REACT_SYS_MESSAGE
9096
tools = self.tools_descriptions
9197
print("REACT_SYS_MESSAGE: ", REACT_SYS_MESSAGE)
9298

@@ -174,10 +180,18 @@ class ReActAgentNodeLlama:
174180
A workaround for open-source llm served by TGI-gaudi.
175181
"""
176182

177-
def __init__(self, tools, args, store=None):
178-
from .prompt import REACT_AGENT_LLAMA_PROMPT
183+
def __init__(self, tools, args, store=None, **kwargs):
179184
from .utils import ReActLlamaOutputParser
180185

186+
if kwargs.get("custom_prompt") is not None:
187+
print("***Custom prompt is provided.")
188+
REACT_AGENT_LLAMA_PROMPT = kwargs.get("custom_prompt").REACT_AGENT_LLAMA_PROMPT
189+
else:
190+
print("*** Using default prompt.")
191+
from .prompt import REACT_AGENT_LLAMA_PROMPT
192+
193+
print("***Prompt template:\n", REACT_AGENT_LLAMA_PROMPT)
194+
181195
output_parser = ReActLlamaOutputParser()
182196
prompt = PromptTemplate(
183197
template=REACT_AGENT_LLAMA_PROMPT,
@@ -244,6 +258,8 @@ def __call__(self, state, config):
244258
ai_message = AIMessage(content=response, tool_calls=tool_calls)
245259
elif "answer" in output[0]:
246260
ai_message = AIMessage(content=str(output[0]["answer"]))
261+
else:
262+
ai_message = AIMessage(content=response)
247263
else:
248264
ai_message = AIMessage(content=response)
249265

@@ -254,7 +270,7 @@ class ReActAgentLlama(BaseAgent):
254270
def __init__(self, args, **kwargs):
255271
super().__init__(args, local_vars=globals(), **kwargs)
256272

257-
agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args, store=self.store)
273+
agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args, store=self.store, **kwargs)
258274
tool_node = ToolNode(self.tools_descriptions)
259275

260276
workflow = StateGraph(AgentState)

comps/agent/src/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ langchain
88
langchain-google-community
99
langchain-huggingface
1010
langchain-openai
11+
langchain-redis
1112
langchain_community
1213
langchainhub
1314
langgraph
@@ -22,6 +23,7 @@ pandas
2223
prometheus_fastapi_instrumentator
2324
pyarrow
2425
pydantic #==1.10.13
26+
rank_bm25
2527

2628
# used by document loader
2729
# beautifulsoup4

comps/agent/src/test.py

Lines changed: 10 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,13 @@
1414
def test_agent_local(args):
1515
from integrations.agent import instantiate_agent
1616

17-
if args.q == 0:
18-
df = pd.DataFrame({"query": ["What is the Intel OPEA Project?"]})
19-
elif args.q == 1:
20-
df = pd.DataFrame({"query": ["what is the trade volume for Microsoft today?"]})
21-
elif args.q == 2:
22-
df = pd.DataFrame({"query": ["what is the hometown of Year 2023 Australia open winner?"]})
23-
24-
agent = instantiate_agent(args, strategy=args.strategy)
25-
app = agent.app
17+
agent = instantiate_agent(args)
2618

2719
config = {"recursion_limit": args.recursion_limit}
2820

29-
traces = []
30-
success = 0
31-
for _, row in df.iterrows():
32-
print("Query: ", row["query"])
33-
initial_state = {"messages": [{"role": "user", "content": row["query"]}]}
34-
try:
35-
trace = {"query": row["query"], "trace": []}
36-
for event in app.stream(initial_state, config=config):
37-
trace["trace"].append(event)
38-
for k, v in event.items():
39-
print("{}: {}".format(k, v))
40-
41-
traces.append(trace)
42-
success += 1
43-
except Exception as e:
44-
print(str(e), str(traceback.format_exc()))
45-
traces.append({"query": row["query"], "trace": str(e)})
46-
47-
print("-" * 50)
21+
query = "What is OPEA project?"
4822

49-
df["trace"] = traces
50-
df.to_csv(os.path.join(args.filedir, args.output), index=False)
51-
print(f"succeed: {success}/{len(df)}")
23+
# run_agent(agent, config, query)
5224

5325

5426
def test_agent_http(args):
@@ -158,15 +130,12 @@ def test_ut(args):
158130
def run_agent(agent, config, input_message):
159131
initial_state = agent.prepare_initial_state(input_message)
160132

161-
try:
162-
for s in agent.app.stream(initial_state, config=config, stream_mode="values"):
163-
message = s["messages"][-1]
164-
message.pretty_print()
133+
for s in agent.app.stream(initial_state, config=config, stream_mode="values"):
134+
message = s["messages"][-1]
135+
message.pretty_print()
165136

166-
last_message = s["messages"][-1]
167-
print("******Response: ", last_message.content)
168-
except Exception as e:
169-
print(str(e))
137+
last_message = s["messages"][-1]
138+
print("******Response: ", last_message.content)
170139

171140

172141
def stream_generator(agent, config, input_message):
@@ -309,4 +278,5 @@ def test_memory(args):
309278
# else:
310279
# print("Please specify the test type")
311280

312-
test_memory(args)
281+
# test_memory(args)
282+
test_agent_local(args)
Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,49 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4-
REACT_SYS_MESSAGE = """\
5-
Custom_prmpt !!!!!!!!!! Decompose the user request into a series of simple tasks when necessary and solve the problem step by step.
4+
REACT_SYS_MESSAGE = """CUSTOM PROMPT
5+
Decompose the user request into a series of simple tasks when necessary and solve the problem step by step.
66
When you cannot get the answer at first, do not give up. Reflect on the info you have from the tools and try to solve the problem in a different way.
77
Please follow these guidelines when formulating your answer:
88
1. If the question contains a false premise or assumption, answer “invalid question”.
9-
2. If you are uncertain or do not know the answer, respond with “I dont know”.
9+
2. If you are uncertain or do not know the answer, respond with “I don't know”.
1010
3. Give concise, factual and relevant answers.
1111
"""
12+
13+
REACT_AGENT_LLAMA_PROMPT = """FINANCIAL ANALYST ASSISTANT
14+
You are a helpful assistant engaged in multi-turn conversations with Financial analysts.
15+
You have access to the following two tools:
16+
{tools}
17+
18+
**Procedure:**
19+
1. Read the question carefully. Divide the question into sub-questions and conquer sub-questions one by one.
20+
3. If there is execution history, read it carefully and reason about the information gathered so far and decide if you can answer the question or if you need to call more tools.
21+
22+
**Output format:**
23+
You should output your thought process. Finish thinking first. Output tool calls or your answer at the end.
24+
When making tool calls, you should use the following format:
25+
TOOL CALL: {{"tool": "tool1", "args": {{"arg1": "value1", "arg2": "value2", ...}}}}
26+
27+
If you can answer the question, provide the answer in the following format:
28+
FINAL ANSWER: {{"answer": "your answer here"}}
29+
30+
31+
======= Conversations with user in previous turns =======
32+
{thread_history}
33+
======= End of previous conversations =======
34+
35+
======= Your execution History in this turn =========
36+
{history}
37+
======= End of execution history ==========
38+
39+
**Tips:**
40+
* You may need to do multi-hop calculations and call tools multiple times to get an answer.
41+
* Do not assume any financial figures. Always rely on the tools to get the factual information.
42+
* If you need a certain financial figure, search for the figure instead of the financial statement name.
43+
* If you did not get the answer at first, do not give up. Reflect on the steps that you have taken and try a different way. Think out of the box. You hard work will be rewarded.
44+
* Give concise, factual and relevant answers.
45+
* If the user question is too ambiguous, ask for clarification.
46+
47+
Now take a deep breath and think step by step to answer user's question in this turn.
48+
USER MESSAGE: {input}
49+
"""

0 commit comments

Comments
 (0)