Skip to content

Commit aab0ade

Browse files
Merge pull request #374 from UiPath/fix_llm_guardrails_check
fix: update llm guardrail payload extraction
2 parents 0dce3d2 + ff5e04b commit aab0ade

File tree

5 files changed

+201
-38
lines changed

5 files changed

+201
-38
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.1.41"
3+
version = "0.1.42"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/agent/guardrails/guardrail_nodes.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818

1919
from uipath_langchain.agent.guardrails.types import ExecutionStage
2020
from uipath_langchain.agent.guardrails.utils import (
21-
_extract_tool_input_data,
21+
_extract_tool_args_from_message,
2222
_extract_tool_output_data,
23+
_extract_tools_args_from_message,
2324
get_message_content,
2425
)
2526
from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
@@ -188,7 +189,11 @@ def create_llm_guardrail_node(
188189
def _payload_generator(state: AgentGuardrailsGraphState) -> str:
189190
if not state.messages:
190191
return ""
191-
return get_message_content(state.messages[-1])
192+
match execution_stage:
193+
case ExecutionStage.PRE_EXECUTION:
194+
return get_message_content(state.messages[-1])
195+
case ExecutionStage.POST_EXECUTION:
196+
return json.dumps(_extract_tools_args_from_message(state.messages[-1]))
192197

193198
return _create_guardrail_node(
194199
guardrail,
@@ -273,16 +278,25 @@ def _payload_generator(state: AgentGuardrailsGraphState) -> str:
273278
return ""
274279

275280
if execution_stage == ExecutionStage.PRE_EXECUTION:
276-
# Extract tool args as dict and convert to JSON string
277-
args_dict = _extract_tool_input_data(state, tool_name, execution_stage)
281+
last_message = state.messages[-1]
282+
args_dict = _extract_tool_args_from_message(last_message, tool_name)
278283
if args_dict:
279284
return json.dumps(args_dict)
280285

281286
return get_message_content(state.messages[-1])
282287

283288
# Create closures for input/output data extraction (for deterministic guardrails)
284289
def _input_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
285-
return _extract_tool_input_data(state, tool_name, execution_stage)
290+
if execution_stage == ExecutionStage.PRE_EXECUTION:
291+
if len(state.messages) < 1:
292+
return {}
293+
message = state.messages[-1]
294+
else: # POST_EXECUTION
295+
if len(state.messages) < 2:
296+
return {}
297+
message = state.messages[-2]
298+
299+
return _extract_tool_args_from_message(message, tool_name)
286300

287301
def _output_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
288302
return _extract_tool_output_data(state)

src/uipath_langchain/agent/guardrails/utils.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
ToolMessage,
1111
)
1212

13-
from uipath_langchain.agent.guardrails.types import ExecutionStage
1413
from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
1514
from uipath_langchain.agent.tools.utils import sanitize_tool_name
1615

@@ -61,45 +60,42 @@ def _extract_tool_args_from_message(
6160
return parsed
6261
except json.JSONDecodeError:
6362
logger.warning(
64-
"Failed to parse tool args as JSON for tool '%s': %s",
65-
tool_name,
66-
args[:100] if len(args) > 100 else args,
63+
"Failed to parse tool args as JSON for tool '%s'", tool_name
6764
)
6865
return {}
6966

7067
return {}
7168

7269

73-
def _extract_tool_input_data(
74-
state: AgentGuardrailsGraphState, tool_name: str, execution_stage: ExecutionStage
75-
) -> dict[str, Any]:
76-
"""Extract tool call arguments as dict for deterministic guardrails.
70+
def _extract_tools_args_from_message(message: AnyMessage) -> list[dict[str, Any]]:
71+
if not isinstance(message, AIMessage):
72+
return []
7773

78-
Args:
79-
state: The current agent graph state.
80-
tool_name: Name of the tool to extract arguments from.
81-
execution_stage: PRE_EXECUTION or POST_EXECUTION.
74+
if not message.tool_calls:
75+
return []
8276

83-
Returns:
84-
Dict containing tool call arguments, or empty dict if not found.
85-
- For PRE_EXECUTION: extracts from last message
86-
- For POST_EXECUTION: extracts from second-to-last message
87-
"""
88-
if not state.messages:
89-
return {}
77+
result: list[dict[str, Any]] = []
9078

91-
# For PRE_EXECUTION, look at last message
92-
# For POST_EXECUTION, look at second-to-last message (before the ToolMessage)
93-
if execution_stage == ExecutionStage.PRE_EXECUTION:
94-
if len(state.messages) < 1:
95-
return {}
96-
message = state.messages[-1]
97-
else: # POST_EXECUTION
98-
if len(state.messages) < 2:
99-
return {}
100-
message = state.messages[-2]
101-
102-
return _extract_tool_args_from_message(message, tool_name)
79+
for tool_call in message.tool_calls:
80+
args = (
81+
tool_call.get("args")
82+
if isinstance(tool_call, dict)
83+
else getattr(tool_call, "args", None)
84+
)
85+
if args is not None:
86+
# Args should already be a dict
87+
if isinstance(args, dict):
88+
result.append(args)
89+
# If it's a JSON string, parse it
90+
elif isinstance(args, str):
91+
try:
92+
parsed = json.loads(args)
93+
if isinstance(parsed, dict):
94+
result.append(parsed)
95+
except json.JSONDecodeError:
96+
logger.warning("Failed to parse tool args as JSON")
97+
98+
return result
10399

104100

105101
def _extract_tool_output_data(state: AgentGuardrailsGraphState) -> dict[str, Any]:
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""Tests for guardrail utility functions."""
2+
3+
import json
4+
5+
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
6+
7+
from uipath_langchain.agent.guardrails.utils import (
8+
_extract_tool_args_from_message,
9+
_extract_tool_output_data,
10+
_extract_tools_args_from_message,
11+
get_message_content,
12+
)
13+
from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
14+
15+
16+
class TestExtractToolArgsFromMessage:
17+
"""Tests for _extract_tool_args_from_message function."""
18+
19+
def test_extracts_args_from_matching_tool(self):
20+
"""Should extract args from matching tool call."""
21+
message = AIMessage(
22+
content="",
23+
tool_calls=[
24+
{
25+
"name": "test_tool",
26+
"args": {"param1": "value1", "param2": 123},
27+
"id": "call_1",
28+
}
29+
],
30+
)
31+
result = _extract_tool_args_from_message(message, "test_tool")
32+
assert result == {"param1": "value1", "param2": 123}
33+
34+
def test_returns_empty_dict_for_non_matching_tool(self):
35+
"""Should return empty dict when tool name doesn't match."""
36+
message = AIMessage(
37+
content="",
38+
tool_calls=[
39+
{"name": "other_tool", "args": {"data": "value"}, "id": "call_1"}
40+
],
41+
)
42+
result = _extract_tool_args_from_message(message, "test_tool")
43+
assert result == {}
44+
45+
def test_returns_empty_dict_for_non_ai_message(self):
46+
"""Should return empty dict when message is not AIMessage."""
47+
message = HumanMessage(content="Test message")
48+
result = _extract_tool_args_from_message(message, "test_tool")
49+
assert result == {}
50+
51+
def test_returns_first_matching_tool_when_multiple(self):
52+
"""Should return args from first matching tool call."""
53+
message = AIMessage(
54+
content="",
55+
tool_calls=[
56+
{"name": "test_tool", "args": {"first": "call"}, "id": "call_1"},
57+
{"name": "test_tool", "args": {"second": "call"}, "id": "call_2"},
58+
],
59+
)
60+
result = _extract_tool_args_from_message(message, "test_tool")
61+
assert result == {"first": "call"}
62+
63+
64+
class TestExtractToolsArgsFromMessage:
65+
"""Tests for _extract_tools_args_from_message function."""
66+
67+
def test_extracts_args_from_all_tool_calls(self):
68+
"""Should extract args from all tool calls."""
69+
message = AIMessage(
70+
content="",
71+
tool_calls=[
72+
{"name": "tool1", "args": {"arg1": "val1"}, "id": "call_1"},
73+
{"name": "tool2", "args": {"arg2": "val2"}, "id": "call_2"},
74+
{"name": "tool3", "args": {"arg3": "val3"}, "id": "call_3"},
75+
],
76+
)
77+
result = _extract_tools_args_from_message(message)
78+
assert result == [{"arg1": "val1"}, {"arg2": "val2"}, {"arg3": "val3"}]
79+
80+
def test_returns_empty_list_for_non_ai_message(self):
81+
"""Should return empty list when message is not AIMessage."""
82+
message = HumanMessage(content="Test message")
83+
result = _extract_tools_args_from_message(message)
84+
assert result == []
85+
86+
def test_returns_empty_list_when_no_tool_calls(self):
87+
"""Should return empty list when AIMessage has no tool calls."""
88+
message = AIMessage(content="Test response")
89+
result = _extract_tools_args_from_message(message)
90+
assert result == []
91+
92+
93+
class TestExtractToolOutputData:
94+
"""Tests for _extract_tool_output_data function."""
95+
96+
def test_extracts_json_dict_content(self):
97+
"""Should parse and return dict when content is JSON string."""
98+
json_content = json.dumps({"result": "success", "data": {"value": 42}})
99+
state = AgentGuardrailsGraphState(
100+
messages=[ToolMessage(content=json_content, tool_call_id="call_1")]
101+
)
102+
result = _extract_tool_output_data(state)
103+
assert result == {"result": "success", "data": {"value": 42}}
104+
105+
def test_wraps_non_json_string_in_output_field(self):
106+
"""Should wrap non-JSON string content in 'output' field."""
107+
state = AgentGuardrailsGraphState(
108+
messages=[ToolMessage(content="Plain text result", tool_call_id="call_1")]
109+
)
110+
result = _extract_tool_output_data(state)
111+
assert result == {"output": "Plain text result"}
112+
113+
def test_returns_empty_dict_for_empty_messages(self):
114+
"""Should return empty dict when state has no messages."""
115+
state = AgentGuardrailsGraphState(messages=[])
116+
result = _extract_tool_output_data(state)
117+
assert result == {}
118+
119+
def test_returns_empty_dict_for_non_tool_message(self):
120+
"""Should return empty dict when last message is not ToolMessage."""
121+
state = AgentGuardrailsGraphState(
122+
messages=[AIMessage(content="Not a tool message")]
123+
)
124+
result = _extract_tool_output_data(state)
125+
assert result == {}
126+
127+
128+
class TestGetMessageContent:
129+
"""Tests for get_message_content function."""
130+
131+
def test_extracts_string_content_from_human_message(self):
132+
"""Should extract string content from HumanMessage."""
133+
message = HumanMessage(content="Hello from human")
134+
result = get_message_content(message)
135+
assert result == "Hello from human"
136+
137+
def test_extracts_content_from_ai_message(self):
138+
"""Should extract content from AIMessage."""
139+
message = AIMessage(content="AI response")
140+
result = get_message_content(message)
141+
assert result == "AI response"
142+
143+
def test_extracts_content_from_tool_message(self):
144+
"""Should extract content from ToolMessage."""
145+
message = ToolMessage(content="Tool result", tool_call_id="call_1")
146+
result = get_message_content(message)
147+
assert result == "Tool result"
148+
149+
def test_handles_empty_content(self):
150+
"""Should handle empty content string."""
151+
message = AIMessage(content="")
152+
result = get_message_content(message)
153+
assert result == ""

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)