|
10 | 10 | GuardrailScope, |
11 | 11 | ) |
12 | 12 |
|
13 | | -from uipath_langchain.agent.guardrails.types import ExecutionStage |
14 | | - |
15 | | -from .actions.base_action import GuardrailAction, GuardrailActionNode |
16 | | -from .guardrail_nodes import ( |
17 | | - create_agent_guardrail_node, |
| 13 | +from uipath_langchain.agent.guardrails.actions.base_action import ( |
| 14 | + GuardrailAction, |
| 15 | + GuardrailActionNode, |
| 16 | +) |
| 17 | +from uipath_langchain.agent.guardrails.guardrail_nodes import ( |
| 18 | + create_agent_init_guardrail_node, |
| 19 | + create_agent_terminate_guardrail_node, |
18 | 20 | create_llm_guardrail_node, |
19 | 21 | create_tool_guardrail_node, |
20 | 22 | ) |
21 | | -from .types import AgentGuardrailsGraphState |
| 23 | +from uipath_langchain.agent.guardrails.types import ExecutionStage |
| 24 | +from uipath_langchain.agent.react.types import ( |
| 25 | + AgentGraphState, |
| 26 | + AgentGuardrailsGraphState, |
| 27 | +) |
22 | 28 |
|
23 | 29 | _VALIDATOR_ALLOWED_STAGES = { |
24 | 30 | "prompt_injection": {ExecutionStage.PRE_EXECUTION}, |
@@ -232,32 +238,65 @@ def create_tools_guardrails_subgraph( |
232 | 238 | return result |
233 | 239 |
|
234 | 240 |
|
235 | | -def create_agent_guardrails_subgraph( |
236 | | - agent_node: tuple[str, Any], |
| 241 | +def create_agent_init_guardrails_subgraph( |
| 242 | + init_node: tuple[str, Any], |
237 | 243 | guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None, |
238 | | - execution_stage: ExecutionStage, |
239 | 244 | ): |
240 | | - """Create a subgraph for AGENT-scoped guardrails that applies checks at the specified stage. |
241 | | -
|
242 | | - This is intended for wrapping nodes like INIT or TERMINATE, where guardrails should run |
243 | | - either before (pre-execution) or after (post-execution) the node logic. |
244 | | - """ |
| 245 | + """Create a subgraph for INIT node that applies guardrails on the state messages.""" |
245 | 246 | applicable_guardrails = [ |
246 | 247 | (guardrail, _) |
247 | 248 | for (guardrail, _) in (guardrails or []) |
248 | 249 | if GuardrailScope.AGENT in guardrail.selector.scopes |
249 | 250 | ] |
250 | 251 | if applicable_guardrails is None or len(applicable_guardrails) == 0: |
251 | | - return agent_node[1] |
| 252 | + return init_node[1] |
252 | 253 |
|
253 | 254 | return _create_guardrails_subgraph( |
254 | | - main_inner_node=agent_node, |
| 255 | + main_inner_node=init_node, |
| 256 | + guardrails=applicable_guardrails, |
| 257 | + scope=GuardrailScope.AGENT, |
| 258 | + execution_stages=[ExecutionStage.POST_EXECUTION], |
| 259 | + node_factory=create_agent_init_guardrail_node, |
| 260 | + ) |
| 261 | + |
| 262 | + |
| 263 | +def create_agent_terminate_guardrails_subgraph( |
| 264 | + terminate_node: tuple[str, Any], |
| 265 | + guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None, |
| 266 | +): |
| 267 | + """Create a subgraph for TERMINATE node that applies guardrails on the agent result.""" |
| 268 | + node_name, node_func = terminate_node |
| 269 | + |
| 270 | + def terminate_wrapper(state: Any) -> dict[str, Any]: |
| 271 | + # Call original terminate node |
| 272 | + result = node_func(state) |
| 273 | + # Store result in state |
| 274 | + return {"agent_result": result, "messages": state.messages} |
| 275 | + |
| 276 | + applicable_guardrails = [ |
| 277 | + (guardrail, _) |
| 278 | + for (guardrail, _) in (guardrails or []) |
| 279 | + if GuardrailScope.AGENT in guardrail.selector.scopes |
| 280 | + ] |
| 281 | + if applicable_guardrails is None or len(applicable_guardrails) == 0: |
| 282 | + return terminate_node[1] |
| 283 | + |
| 284 | + subgraph = _create_guardrails_subgraph( |
| 285 | + main_inner_node=(node_name, terminate_wrapper), |
255 | 286 | guardrails=applicable_guardrails, |
256 | 287 | scope=GuardrailScope.AGENT, |
257 | | - execution_stages=[execution_stage], |
258 | | - node_factory=create_agent_guardrail_node, |
| 288 | + execution_stages=[ExecutionStage.POST_EXECUTION], |
| 289 | + node_factory=create_agent_terminate_guardrail_node, |
259 | 290 | ) |
260 | 291 |
|
| 292 | + async def run_terminate_subgraph( |
| 293 | + state: AgentGraphState, |
| 294 | + ) -> dict[str, Any]: |
| 295 | + result_state = await subgraph.ainvoke(state) |
| 296 | + return result_state["agent_result"] |
| 297 | + |
| 298 | + return run_terminate_subgraph |
| 299 | + |
261 | 300 |
|
262 | 301 | def create_tool_guardrails_subgraph( |
263 | 302 | tool_node: tuple[str, Any], |
|
0 commit comments