Skip to content

Commit 3567f60

Browse files
Merge pull request #381 from UiPath/feat/tool-runtime-injection
fix: escalation tool missing runtime parameter
2 parents 76e18ff + e207324 commit 3567f60

File tree

5 files changed

+46
-22
lines changed

5 files changed

+46
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.3.0"
3+
version = "0.3.1"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/agent/tools/escalation_tool.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from enum import Enum
44
from typing import Any
55

6-
from langchain.tools import ToolRuntime
76
from langchain_core.messages import ToolMessage
8-
from langchain_core.tools import StructuredTool
7+
from langchain_core.messages.tool import ToolCall
8+
from langchain_core.tools import BaseTool, StructuredTool
99
from langgraph.types import Command, interrupt
1010
from uipath.agent.models.agent import (
1111
AgentEscalationChannel,
@@ -17,6 +17,8 @@
1717

1818
from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model
1919

20+
from ..react.types import AgentGraphNode, AgentGraphState, AgentTerminationSource
21+
from .tool_node import ToolWrapperMixin
2022
from .utils import sanitize_tool_name
2123

2224

@@ -27,7 +29,11 @@ class EscalationAction(str, Enum):
2729
END = "end"
2830

2931

30-
def create_escalation_tool(resource: AgentEscalationResourceConfig) -> StructuredTool:
32+
class StructuredToolWithWrapper(StructuredTool, ToolWrapperMixin):
33+
pass
34+
35+
36+
def create_escalation_tool(resource: AgentEscalationResourceConfig) -> BaseTool:
3137
"""Uses interrupt() for Action Center human-in-the-loop."""
3238

3339
tool_name: str = f"escalate_{sanitize_tool_name(resource.name)}"
@@ -50,9 +56,7 @@ def create_escalation_tool(resource: AgentEscalationResourceConfig) -> Structure
5056
output_schema=output_model.model_json_schema(),
5157
example_calls=channel.properties.example_calls,
5258
)
53-
async def escalation_tool_fn(
54-
runtime: ToolRuntime, **kwargs: Any
55-
) -> Command[Any] | Any:
59+
async def escalation_tool_fn(**kwargs: Any) -> dict[str, Any]:
5660
task_title = channel.task_title or "Escalation Task"
5761

5862
result = interrupt(
@@ -73,23 +77,41 @@ async def escalation_tool_fn(
7377
escalation_action = getattr(result, "action", None)
7478
escalation_output = getattr(result, "data", {})
7579

76-
outcome = (
80+
outcome_str = (
7781
channel.outcome_mapping.get(escalation_action)
7882
if channel.outcome_mapping and escalation_action
7983
else None
8084
)
85+
outcome = (
86+
EscalationAction(outcome_str) if outcome_str else EscalationAction.CONTINUE
87+
)
8188

82-
if outcome == EscalationAction.END:
83-
output_detail = f"Escalation output: {escalation_output}"
84-
termination_title = f"Agent run ended based on escalation outcome {outcome} with directive {escalation_action}"
85-
from ..react.types import AgentGraphNode, AgentTerminationSource
89+
return {
90+
"action": outcome,
91+
"output": escalation_output,
92+
"escalation_action": escalation_action,
93+
}
94+
95+
async def escalation_wrapper(
96+
tool: BaseTool,
97+
call: ToolCall,
98+
state: AgentGraphState,
99+
) -> dict[str, Any] | Command[Any]:
100+
result = await tool.ainvoke(call["args"])
101+
102+
if result["action"] == EscalationAction.END:
103+
output_detail = f"Escalation output: {result['output']}"
104+
termination_title = (
105+
f"Agent run ended based on escalation outcome {result['action']} "
106+
f"with directive {result['escalation_action']}"
107+
)
86108

87109
return Command(
88110
update={
89111
"messages": [
90112
ToolMessage(
91113
content=f"{termination_title}. {output_detail}",
92-
tool_call_id=runtime.tool_call_id,
114+
tool_call_id=call["id"],
93115
)
94116
],
95117
"termination": {
@@ -101,9 +123,9 @@ async def escalation_tool_fn(
101123
goto=AgentGraphNode.TERMINATE,
102124
)
103125

104-
return escalation_output
126+
return result["output"]
105127

106-
tool = StructuredTool(
128+
tool = StructuredToolWithWrapper(
107129
name=tool_name,
108130
description=resource.description,
109131
args_schema=input_model,
@@ -115,5 +137,6 @@ async def escalation_tool_fn(
115137
"assignee": assignee,
116138
},
117139
)
140+
tool.set_tool_wrappers(awrapper=escalation_wrapper)
118141

119142
return tool

src/uipath_langchain/agent/tools/tool_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Factory functions for creating tools from agent resources."""
22

33
from langchain_core.language_models import BaseChatModel
4-
from langchain_core.tools import BaseTool, StructuredTool
4+
from langchain_core.tools import BaseTool
55
from uipath.agent.models.agent import (
66
AgentContextResourceConfig,
77
AgentEscalationResourceConfig,
@@ -34,7 +34,7 @@ async def create_tools_from_resources(
3434

3535
async def _build_tool_for_resource(
3636
resource: BaseAgentResourceConfig, llm: BaseChatModel
37-
) -> StructuredTool | None:
37+
) -> BaseTool | None:
3838
if isinstance(resource, AgentProcessToolResourceConfig):
3939
return create_process_tool(resource)
4040

src/uipath_langchain/agent/tools/tool_node.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from langchain_core.messages.ai import AIMessage
88
from langchain_core.messages.tool import ToolCall, ToolMessage
9+
from langchain_core.runnables.config import RunnableConfig
910
from langchain_core.tools import BaseTool
1011
from langgraph._internal._runnable import RunnableCallable
1112
from langgraph.types import Command
@@ -48,7 +49,7 @@ def __init__(
4849
self.wrapper = wrapper
4950
self.awrapper = awrapper
5051

51-
def _func(self, state: Any) -> OutputType:
52+
def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType:
5253
call = self._extract_tool_call(state)
5354
if call is None:
5455
return None
@@ -57,10 +58,11 @@ def _func(self, state: Any) -> OutputType:
5758
result = self.wrapper(self.tool, call, filtered_state)
5859
else:
5960
result = self.tool.invoke(call["args"])
60-
6161
return self._process_result(call, result)
6262

63-
async def _afunc(self, state: Any) -> OutputType:
63+
async def _afunc(
64+
self, state: Any, config: RunnableConfig | None = None
65+
) -> OutputType:
6466
call = self._extract_tool_call(state)
6567
if call is None:
6668
return None
@@ -69,7 +71,6 @@ async def _afunc(self, state: Any) -> OutputType:
6971
result = await self.awrapper(self.tool, call, filtered_state)
7072
else:
7173
result = await self.tool.ainvoke(call["args"])
72-
7374
return self._process_result(call, result)
7475

7576
def _extract_tool_call(self, state: Any) -> ToolCall | None:

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)