Skip to content

Commit 731d81b

Browse files
Merge pull request #364 from UiPath/feat/AL-255_guardrails_filter_action
feat: Add deterministic guardrails filtering action
2 parents f36fd48 + 6e452d2 commit 731d81b

File tree

8 files changed

+957
-30
lines changed

8 files changed

+957
-30
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.1.38"
3+
version = "0.1.39"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/agent/guardrails/actions/filter_action.py

Lines changed: 247 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import re
22
from typing import Any
33

4+
from langchain_core.messages import AIMessage, ToolMessage
5+
from langgraph.types import Command
6+
from uipath.core.guardrails.guardrails import FieldReference, FieldSource
47
from uipath.platform.guardrails import BaseGuardrail, GuardrailScope
58
from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode
69

@@ -14,11 +17,18 @@
1417
class FilterAction(GuardrailAction):
1518
"""Action that filters inputs/outputs on guardrail failure.
1619
17-
For now, filtering is only supported for non-AGENT and non-LLM scopes.
18-
If invoked for ``GuardrailScope.AGENT`` or ``GuardrailScope.LLM``, this action
19-
raises an exception to indicate the operation is not supported yet.
20+
For Tool scope, this action removes specified fields from tool call arguments.
21+
For AGENT and LLM scopes, this action raises an exception as it's not supported yet.
2022
"""
2123

24+
def __init__(self, fields: list[FieldReference] | None = None):
25+
"""Initialize FilterAction with fields to filter.
26+
27+
Args:
28+
fields: List of FieldReference objects specifying which fields to filter.
29+
"""
30+
self.fields = fields or []
31+
2232
def action_node(
2333
self,
2434
*,
@@ -41,15 +51,240 @@ def action_node(
4151
raw_node_name = f"{scope.name}_{execution_stage.name}_{guardrail.name}_filter"
4252
node_name = re.sub(r"\W+", "_", raw_node_name.lower()).strip("_")
4353

44-
async def _node(_state: AgentGuardrailsGraphState) -> dict[str, Any]:
45-
if scope in (GuardrailScope.AGENT, GuardrailScope.LLM):
46-
raise AgentTerminationException(
47-
code=UiPathErrorCode.EXECUTION_ERROR,
48-
title="Guardrail filter action not supported",
49-
detail=f"FilterAction is not supported for scope [{scope.name}] at this time.",
50-
category=UiPathErrorCategory.USER,
54+
async def _node(
55+
_state: AgentGuardrailsGraphState,
56+
) -> dict[str, Any] | Command[Any]:
57+
if scope == GuardrailScope.TOOL:
58+
return _filter_tool_fields(
59+
_state,
60+
self.fields,
61+
execution_stage,
62+
guarded_component_name,
63+
guardrail.name,
5164
)
52-
# No-op for other scopes for now.
53-
return {}
65+
66+
raise AgentTerminationException(
67+
code=UiPathErrorCode.EXECUTION_ERROR,
68+
title="Guardrail filter action not supported",
69+
detail=f"FilterAction is not supported for scope [{scope.name}] at this time.",
70+
category=UiPathErrorCategory.USER,
71+
)
5472

5573
return node_name, _node
74+
75+
76+
def _filter_tool_fields(
77+
state: AgentGuardrailsGraphState,
78+
fields_to_filter: list[FieldReference],
79+
execution_stage: ExecutionStage,
80+
tool_name: str,
81+
guardrail_name: str,
82+
) -> dict[str, Any] | Command[Any]:
83+
"""Filter specified fields from tool call arguments or tool output.
84+
85+
The filter action filters fields based on the execution stage:
86+
- PRE_EXECUTION: Only input fields are filtered
87+
- POST_EXECUTION: Only output fields are filtered
88+
89+
Args:
90+
state: The current agent graph state.
91+
fields_to_filter: List of FieldReference objects specifying which fields to filter.
92+
execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
93+
tool_name: Name of the tool to filter.
94+
guardrail_name: Name of the guardrail for logging purposes.
95+
96+
Returns:
97+
Command to update messages with filtered tool call args or output.
98+
99+
Raises:
100+
AgentTerminationException: If filtering fails.
101+
"""
102+
try:
103+
if not fields_to_filter:
104+
return {}
105+
106+
if execution_stage == ExecutionStage.PRE_EXECUTION:
107+
return _filter_tool_input_fields(state, fields_to_filter, tool_name)
108+
else:
109+
return _filter_tool_output_fields(state, fields_to_filter)
110+
111+
except Exception as e:
112+
raise AgentTerminationException(
113+
code=UiPathErrorCode.EXECUTION_ERROR,
114+
title="Filter action failed",
115+
detail=f"Failed to filter tool fields: {str(e)}",
116+
category=UiPathErrorCategory.USER,
117+
) from e
118+
119+
120+
def _filter_tool_input_fields(
121+
state: AgentGuardrailsGraphState,
122+
fields_to_filter: list[FieldReference],
123+
tool_name: str,
124+
) -> dict[str, Any] | Command[Any]:
125+
"""Filter specified input fields from tool call arguments (PRE_EXECUTION only).
126+
127+
This function is called at PRE_EXECUTION to filter input fields from tool call arguments
128+
before the tool is executed.
129+
130+
Args:
131+
state: The current agent graph state.
132+
fields_to_filter: List of FieldReference objects specifying which fields to filter.
133+
tool_name: Name of the tool to filter.
134+
135+
Returns:
136+
Command to update messages with filtered tool call args, or empty dict if no input fields to filter.
137+
"""
138+
# Check if there are any input fields to filter
139+
has_input_fields = any(
140+
field_ref.source == FieldSource.INPUT for field_ref in fields_to_filter
141+
)
142+
143+
if not has_input_fields:
144+
return {}
145+
146+
msgs = state.messages.copy()
147+
if not msgs:
148+
return {}
149+
150+
# Find the AIMessage with tool calls
151+
# At PRE_EXECUTION, this is always the last message
152+
ai_message = None
153+
for i in range(len(msgs) - 1, -1, -1):
154+
msg = msgs[i]
155+
if isinstance(msg, AIMessage) and msg.tool_calls:
156+
ai_message = msg
157+
break
158+
159+
if ai_message is None:
160+
return {}
161+
162+
# Find and filter the tool call with matching name
163+
# Type assertion: we know ai_message is AIMessage from the check above
164+
assert isinstance(ai_message, AIMessage)
165+
tool_calls = list(ai_message.tool_calls)
166+
modified = False
167+
168+
for tool_call in tool_calls:
169+
call_name = (
170+
tool_call.get("name")
171+
if isinstance(tool_call, dict)
172+
else getattr(tool_call, "name", None)
173+
)
174+
175+
if call_name == tool_name:
176+
# Get the current args
177+
args = (
178+
tool_call.get("args")
179+
if isinstance(tool_call, dict)
180+
else getattr(tool_call, "args", None)
181+
)
182+
183+
if args and isinstance(args, dict):
184+
# Filter out the specified input fields
185+
filtered_args = args.copy()
186+
for field_ref in fields_to_filter:
187+
# Only filter input fields
188+
if (
189+
field_ref.source == FieldSource.INPUT
190+
and field_ref.path in filtered_args
191+
):
192+
del filtered_args[field_ref.path]
193+
modified = True
194+
195+
# Update the tool call with filtered args
196+
if isinstance(tool_call, dict):
197+
tool_call["args"] = filtered_args
198+
else:
199+
tool_call.args = filtered_args
200+
201+
break
202+
203+
if modified:
204+
ai_message.tool_calls = tool_calls
205+
return Command(update={"messages": msgs})
206+
207+
return {}
208+
209+
210+
def _filter_tool_output_fields(
211+
state: AgentGuardrailsGraphState,
212+
fields_to_filter: list[FieldReference],
213+
) -> dict[str, Any] | Command[Any]:
214+
"""Filter specified output fields from tool output (POST_EXECUTION only).
215+
216+
This function is called at POST_EXECUTION to filter output fields from tool results
217+
after the tool has been executed.
218+
219+
Args:
220+
state: The current agent graph state.
221+
fields_to_filter: List of FieldReference objects specifying which fields to filter.
222+
223+
Returns:
224+
Command to update messages with filtered tool output, or empty dict if no output fields to filter.
225+
"""
226+
# Check if there are any output fields to filter
227+
has_output_fields = any(
228+
field_ref.source == FieldSource.OUTPUT for field_ref in fields_to_filter
229+
)
230+
231+
if not has_output_fields:
232+
return {}
233+
234+
msgs = state.messages.copy()
235+
if not msgs:
236+
return {}
237+
238+
last_message = msgs[-1]
239+
if not isinstance(last_message, ToolMessage):
240+
return {}
241+
242+
# Parse the tool output content
243+
import json
244+
245+
content = last_message.content
246+
if not content:
247+
return {}
248+
249+
# Try to parse the content as JSON or dict
250+
try:
251+
if isinstance(content, dict):
252+
output_data = content
253+
elif isinstance(content, str):
254+
try:
255+
output_data = json.loads(content)
256+
except json.JSONDecodeError:
257+
# Try to parse as Python literal (dict representation)
258+
import ast
259+
260+
try:
261+
output_data = ast.literal_eval(content)
262+
if not isinstance(output_data, dict):
263+
return {}
264+
except (ValueError, SyntaxError):
265+
return {}
266+
else:
267+
# Content is not JSON-parseable, can't filter specific fields
268+
return {}
269+
except Exception:
270+
return {}
271+
272+
if not isinstance(output_data, dict):
273+
return {}
274+
275+
# Filter out the specified fields
276+
filtered_output = output_data.copy()
277+
modified = False
278+
279+
for field_ref in fields_to_filter:
280+
# Only filter output fields
281+
if field_ref.source == FieldSource.OUTPUT and field_ref.path in filtered_output:
282+
del filtered_output[field_ref.path]
283+
modified = True
284+
285+
if modified:
286+
# Update the tool message content with filtered output
287+
last_message.content = json.dumps(filtered_output)
288+
return Command(update={"messages": msgs})
289+
290+
return {}

src/uipath_langchain/agent/guardrails/guardrails_factory.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AgentGuardrail,
1010
AgentGuardrailBlockAction,
1111
AgentGuardrailEscalateAction,
12+
AgentGuardrailFilterAction,
1213
AgentGuardrailLogAction,
1314
AgentGuardrailSeverityLevel,
1415
AgentNumberOperator,
@@ -29,9 +30,11 @@
2930
from uipath_langchain.agent.guardrails.actions import (
3031
BlockAction,
3132
EscalateAction,
33+
FilterAction,
3234
GuardrailAction,
3335
LogAction,
3436
)
37+
from uipath_langchain.agent.guardrails.utils import _sanitize_selector_tool_names
3538

3639

3740
def _assert_value_not_none(value: str | None, operator: AgentWordOperator) -> str:
@@ -191,18 +194,21 @@ def _convert_agent_custom_guardrail_to_deterministic(
191194
guardrail: The agent custom guardrail to convert.
192195
193196
Returns:
194-
A DeterministicGuardrail with converted rules.
197+
A DeterministicGuardrail with converted rules and sanitized selector.
195198
"""
196199
converted_rules = [
197200
_convert_agent_rule_to_deterministic(rule) for rule in guardrail.rules
198201
]
199202

203+
# Sanitize tool names in selector for Tool scope guardrails
204+
sanitized_selector = _sanitize_selector_tool_names(guardrail.selector)
205+
200206
return DeterministicGuardrail(
201207
id=guardrail.id,
202208
name=guardrail.name,
203209
description=guardrail.description,
204210
enabled_for_evals=guardrail.enabled_for_evals,
205-
selector=guardrail.selector,
211+
selector=sanitized_selector,
206212
guardrail_type="custom",
207213
rules=converted_rules,
208214
)
@@ -227,8 +233,7 @@ def build_guardrails_with_actions(
227233
if isinstance(guardrail, AgentUnknownGuardrail):
228234
continue
229235

230-
# Convert AgentCustomGuardrail to DeterministicGuardrail
231-
converted_guardrail: BaseGuardrail = guardrail
236+
converted_guardrail: BaseGuardrail
232237
if isinstance(guardrail, AgentCustomGuardrail):
233238
converted_guardrail = _convert_agent_custom_guardrail_to_deterministic(
234239
guardrail
@@ -246,6 +251,9 @@ def build_guardrails_with_actions(
246251
f"Found invalid scopes: {[scope.name for scope in non_tool_scopes]}. "
247252
f"Please configure this guardrail to use only TOOL scope."
248253
)
254+
else:
255+
converted_guardrail = guardrail
256+
_sanitize_selector_tool_names(converted_guardrail.selector)
249257

250258
action = guardrail.action
251259

@@ -276,4 +284,6 @@ def build_guardrails_with_actions(
276284
),
277285
)
278286
)
287+
elif isinstance(action, AgentGuardrailFilterAction):
288+
result.append((converted_guardrail, FilterAction(fields=action.fields)))
279289
return result

0 commit comments

Comments
 (0)