11import re
22from typing import Any
33
4+ from langchain_core .messages import AIMessage , ToolMessage
5+ from langgraph .types import Command
6+ from uipath .core .guardrails .guardrails import FieldReference , FieldSource
47from uipath .platform .guardrails import BaseGuardrail , GuardrailScope
58from uipath .runtime .errors import UiPathErrorCategory , UiPathErrorCode
69
1417class FilterAction (GuardrailAction ):
1518 """Action that filters inputs/outputs on guardrail failure.
1619
17- For now, filtering is only supported for non-AGENT and non-LLM scopes.
18- If invoked for ``GuardrailScope.AGENT`` or ``GuardrailScope.LLM``, this action
19- raises an exception to indicate the operation is not supported yet.
20+ For Tool scope, this action removes specified fields from tool call arguments.
21+ For AGENT and LLM scopes, this action raises an exception as it's not supported yet.
2022 """
2123
24+ def __init__ (self , fields : list [FieldReference ] | None = None ):
25+ """Initialize FilterAction with fields to filter.
26+
27+ Args:
28+ fields: List of FieldReference objects specifying which fields to filter.
29+ """
30+ self .fields = fields or []
31+
2232 def action_node (
2333 self ,
2434 * ,
@@ -41,15 +51,240 @@ def action_node(
4151 raw_node_name = f"{ scope .name } _{ execution_stage .name } _{ guardrail .name } _filter"
4252 node_name = re .sub (r"\W+" , "_" , raw_node_name .lower ()).strip ("_" )
4353
44- async def _node (_state : AgentGuardrailsGraphState ) -> dict [str , Any ]:
45- if scope in (GuardrailScope .AGENT , GuardrailScope .LLM ):
46- raise AgentTerminationException (
47- code = UiPathErrorCode .EXECUTION_ERROR ,
48- title = "Guardrail filter action not supported" ,
49- detail = f"FilterAction is not supported for scope [{ scope .name } ] at this time." ,
50- category = UiPathErrorCategory .USER ,
54+ async def _node (
55+ _state : AgentGuardrailsGraphState ,
56+ ) -> dict [str , Any ] | Command [Any ]:
57+ if scope == GuardrailScope .TOOL :
58+ return _filter_tool_fields (
59+ _state ,
60+ self .fields ,
61+ execution_stage ,
62+ guarded_component_name ,
63+ guardrail .name ,
5164 )
52- # No-op for other scopes for now.
53- return {}
65+
66+ raise AgentTerminationException (
67+ code = UiPathErrorCode .EXECUTION_ERROR ,
68+ title = "Guardrail filter action not supported" ,
69+ detail = f"FilterAction is not supported for scope [{ scope .name } ] at this time." ,
70+ category = UiPathErrorCategory .USER ,
71+ )
5472
5573 return node_name , _node
74+
75+
76+ def _filter_tool_fields (
77+ state : AgentGuardrailsGraphState ,
78+ fields_to_filter : list [FieldReference ],
79+ execution_stage : ExecutionStage ,
80+ tool_name : str ,
81+ guardrail_name : str ,
82+ ) -> dict [str , Any ] | Command [Any ]:
83+ """Filter specified fields from tool call arguments or tool output.
84+
85+ The filter action filters fields based on the execution stage:
86+ - PRE_EXECUTION: Only input fields are filtered
87+ - POST_EXECUTION: Only output fields are filtered
88+
89+ Args:
90+ state: The current agent graph state.
91+ fields_to_filter: List of FieldReference objects specifying which fields to filter.
92+ execution_stage: The execution stage (PRE_EXECUTION or POST_EXECUTION).
93+ tool_name: Name of the tool to filter.
94+ guardrail_name: Name of the guardrail for logging purposes.
95+
96+ Returns:
97+ Command to update messages with filtered tool call args or output.
98+
99+ Raises:
100+ AgentTerminationException: If filtering fails.
101+ """
102+ try :
103+ if not fields_to_filter :
104+ return {}
105+
106+ if execution_stage == ExecutionStage .PRE_EXECUTION :
107+ return _filter_tool_input_fields (state , fields_to_filter , tool_name )
108+ else :
109+ return _filter_tool_output_fields (state , fields_to_filter )
110+
111+ except Exception as e :
112+ raise AgentTerminationException (
113+ code = UiPathErrorCode .EXECUTION_ERROR ,
114+ title = "Filter action failed" ,
115+ detail = f"Failed to filter tool fields: { str (e )} " ,
116+ category = UiPathErrorCategory .USER ,
117+ ) from e
118+
119+
120+ def _filter_tool_input_fields (
121+ state : AgentGuardrailsGraphState ,
122+ fields_to_filter : list [FieldReference ],
123+ tool_name : str ,
124+ ) -> dict [str , Any ] | Command [Any ]:
125+ """Filter specified input fields from tool call arguments (PRE_EXECUTION only).
126+
127+ This function is called at PRE_EXECUTION to filter input fields from tool call arguments
128+ before the tool is executed.
129+
130+ Args:
131+ state: The current agent graph state.
132+ fields_to_filter: List of FieldReference objects specifying which fields to filter.
133+ tool_name: Name of the tool to filter.
134+
135+ Returns:
136+ Command to update messages with filtered tool call args, or empty dict if no input fields to filter.
137+ """
138+ # Check if there are any input fields to filter
139+ has_input_fields = any (
140+ field_ref .source == FieldSource .INPUT for field_ref in fields_to_filter
141+ )
142+
143+ if not has_input_fields :
144+ return {}
145+
146+ msgs = state .messages .copy ()
147+ if not msgs :
148+ return {}
149+
150+ # Find the AIMessage with tool calls
151+ # At PRE_EXECUTION, this is always the last message
152+ ai_message = None
153+ for i in range (len (msgs ) - 1 , - 1 , - 1 ):
154+ msg = msgs [i ]
155+ if isinstance (msg , AIMessage ) and msg .tool_calls :
156+ ai_message = msg
157+ break
158+
159+ if ai_message is None :
160+ return {}
161+
162+ # Find and filter the tool call with matching name
163+ # Type assertion: we know ai_message is AIMessage from the check above
164+ assert isinstance (ai_message , AIMessage )
165+ tool_calls = list (ai_message .tool_calls )
166+ modified = False
167+
168+ for tool_call in tool_calls :
169+ call_name = (
170+ tool_call .get ("name" )
171+ if isinstance (tool_call , dict )
172+ else getattr (tool_call , "name" , None )
173+ )
174+
175+ if call_name == tool_name :
176+ # Get the current args
177+ args = (
178+ tool_call .get ("args" )
179+ if isinstance (tool_call , dict )
180+ else getattr (tool_call , "args" , None )
181+ )
182+
183+ if args and isinstance (args , dict ):
184+ # Filter out the specified input fields
185+ filtered_args = args .copy ()
186+ for field_ref in fields_to_filter :
187+ # Only filter input fields
188+ if (
189+ field_ref .source == FieldSource .INPUT
190+ and field_ref .path in filtered_args
191+ ):
192+ del filtered_args [field_ref .path ]
193+ modified = True
194+
195+ # Update the tool call with filtered args
196+ if isinstance (tool_call , dict ):
197+ tool_call ["args" ] = filtered_args
198+ else :
199+ tool_call .args = filtered_args
200+
201+ break
202+
203+ if modified :
204+ ai_message .tool_calls = tool_calls
205+ return Command (update = {"messages" : msgs })
206+
207+ return {}
208+
209+
210+ def _filter_tool_output_fields (
211+ state : AgentGuardrailsGraphState ,
212+ fields_to_filter : list [FieldReference ],
213+ ) -> dict [str , Any ] | Command [Any ]:
214+ """Filter specified output fields from tool output (POST_EXECUTION only).
215+
216+ This function is called at POST_EXECUTION to filter output fields from tool results
217+ after the tool has been executed.
218+
219+ Args:
220+ state: The current agent graph state.
221+ fields_to_filter: List of FieldReference objects specifying which fields to filter.
222+
223+ Returns:
224+ Command to update messages with filtered tool output, or empty dict if no output fields to filter.
225+ """
226+ # Check if there are any output fields to filter
227+ has_output_fields = any (
228+ field_ref .source == FieldSource .OUTPUT for field_ref in fields_to_filter
229+ )
230+
231+ if not has_output_fields :
232+ return {}
233+
234+ msgs = state .messages .copy ()
235+ if not msgs :
236+ return {}
237+
238+ last_message = msgs [- 1 ]
239+ if not isinstance (last_message , ToolMessage ):
240+ return {}
241+
242+ # Parse the tool output content
243+ import json
244+
245+ content = last_message .content
246+ if not content :
247+ return {}
248+
249+ # Try to parse the content as JSON or dict
250+ try :
251+ if isinstance (content , dict ):
252+ output_data = content
253+ elif isinstance (content , str ):
254+ try :
255+ output_data = json .loads (content )
256+ except json .JSONDecodeError :
257+ # Try to parse as Python literal (dict representation)
258+ import ast
259+
260+ try :
261+ output_data = ast .literal_eval (content )
262+ if not isinstance (output_data , dict ):
263+ return {}
264+ except (ValueError , SyntaxError ):
265+ return {}
266+ else :
267+ # Content is not JSON-parseable, can't filter specific fields
268+ return {}
269+ except Exception :
270+ return {}
271+
272+ if not isinstance (output_data , dict ):
273+ return {}
274+
275+ # Filter out the specified fields
276+ filtered_output = output_data .copy ()
277+ modified = False
278+
279+ for field_ref in fields_to_filter :
280+ # Only filter output fields
281+ if field_ref .source == FieldSource .OUTPUT and field_ref .path in filtered_output :
282+ del filtered_output [field_ref .path ]
283+ modified = True
284+
285+ if modified :
286+ # Update the tool message content with filtered output
287+ last_message .content = json .dumps (filtered_output )
288+ return Command (update = {"messages" : msgs })
289+
290+ return {}
0 commit comments