33from enum import Enum
44from typing import Any
55
6- from langchain .tools import ToolRuntime
76from langchain_core .messages import ToolMessage
8- from langchain_core .tools import StructuredTool
7+ from langchain_core .messages .tool import ToolCall
8+ from langchain_core .tools import BaseTool , StructuredTool
99from langgraph .types import Command , interrupt
1010from uipath .agent .models .agent import (
1111 AgentEscalationChannel ,
1717
1818from uipath_langchain .agent .react .jsonschema_pydantic_converter import create_model
1919
20+ from ..react .types import AgentGraphNode , AgentGraphState , AgentTerminationSource
21+ from .tool_node import ToolWrapperMixin
2022from .utils import sanitize_tool_name
2123
2224
@@ -27,7 +29,11 @@ class EscalationAction(str, Enum):
2729 END = "end"
2830
2931
30- def create_escalation_tool (resource : AgentEscalationResourceConfig ) -> StructuredTool :
32+ class StructuredToolWithWrapper (StructuredTool , ToolWrapperMixin ):
33+ pass
34+
35+
36+ def create_escalation_tool (resource : AgentEscalationResourceConfig ) -> BaseTool :
3137 """Uses interrupt() for Action Center human-in-the-loop."""
3238
3339 tool_name : str = f"escalate_{ sanitize_tool_name (resource .name )} "
@@ -50,9 +56,7 @@ def create_escalation_tool(resource: AgentEscalationResourceConfig) -> Structure
5056 output_schema = output_model .model_json_schema (),
5157 example_calls = channel .properties .example_calls ,
5258 )
53- async def escalation_tool_fn (
54- runtime : ToolRuntime , ** kwargs : Any
55- ) -> Command [Any ] | Any :
59+ async def escalation_tool_fn (** kwargs : Any ) -> dict [str , Any ]:
5660 task_title = channel .task_title or "Escalation Task"
5761
5862 result = interrupt (
@@ -73,23 +77,41 @@ async def escalation_tool_fn(
7377 escalation_action = getattr (result , "action" , None )
7478 escalation_output = getattr (result , "data" , {})
7579
76- outcome = (
80+ outcome_str = (
7781 channel .outcome_mapping .get (escalation_action )
7882 if channel .outcome_mapping and escalation_action
7983 else None
8084 )
85+ outcome = (
86+ EscalationAction (outcome_str ) if outcome_str else EscalationAction .CONTINUE
87+ )
8188
82- if outcome == EscalationAction .END :
83- output_detail = f"Escalation output: { escalation_output } "
84- termination_title = f"Agent run ended based on escalation outcome { outcome } with directive { escalation_action } "
85- from ..react .types import AgentGraphNode , AgentTerminationSource
89+ return {
90+ "action" : outcome ,
91+ "output" : escalation_output ,
92+ "escalation_action" : escalation_action ,
93+ }
94+
95+ async def escalation_wrapper (
96+ tool : BaseTool ,
97+ call : ToolCall ,
98+ state : AgentGraphState ,
99+ ) -> dict [str , Any ] | Command [Any ]:
100+ result = await tool .ainvoke (call ["args" ])
101+
102+ if result ["action" ] == EscalationAction .END :
103+ output_detail = f"Escalation output: { result ['output' ]} "
104+ termination_title = (
105+ f"Agent run ended based on escalation outcome { result ['action' ]} "
106+ f"with directive { result ['escalation_action' ]} "
107+ )
86108
87109 return Command (
88110 update = {
89111 "messages" : [
90112 ToolMessage (
91113 content = f"{ termination_title } . { output_detail } " ,
92- tool_call_id = runtime . tool_call_id ,
114+ tool_call_id = call [ "id" ] ,
93115 )
94116 ],
95117 "termination" : {
@@ -101,9 +123,9 @@ async def escalation_tool_fn(
101123 goto = AgentGraphNode .TERMINATE ,
102124 )
103125
104- return escalation_output
126+ return result [ "output" ]
105127
106- tool = StructuredTool (
128+ tool = StructuredToolWithWrapper (
107129 name = tool_name ,
108130 description = resource .description ,
109131 args_schema = input_model ,
@@ -115,5 +137,6 @@ async def escalation_tool_fn(
115137 "assignee" : assignee ,
116138 },
117139 )
140+ tool .set_tool_wrappers (awrapper = escalation_wrapper )
118141
119142 return tool
0 commit comments