Skip to content

Commit a55a4e3

Browse files
committed
fix: decouple runtime from graph config
1 parent 1a179b3 commit a55a4e3

File tree

7 files changed

+157
-114
lines changed

7 files changed

+157
-114
lines changed

src/uipath_langchain/_cli/_runtime/_context.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
11
from typing import Any, Optional, Union
22

33
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
4-
from langgraph.graph import StateGraph
54
from uipath._cli._runtime._contracts import UiPathRuntimeContext
65

7-
from .._utils._graph import LangGraphConfig
8-
96

107
class LangGraphRuntimeContext(UiPathRuntimeContext):
118
"""Context information passed throughout the runtime execution."""
129

13-
langgraph_config: Optional[LangGraphConfig] = None
14-
state_graph: Optional[StateGraph[Any, Any]] = None
1510
output: Optional[Any] = None
1611
state: Optional[Any] = (
1712
None # TypedDict issue, the actual type is: Optional[langgraph.types.StateSnapshot]
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from typing import Any, Optional
2+
3+
from langgraph.graph.state import CompiledStateGraph, StateGraph
4+
from uipath._cli._runtime._contracts import (
5+
UiPathErrorCategory,
6+
)
7+
8+
from .._utils._graph import GraphConfig, LangGraphConfig
9+
from ._exception import LangGraphRuntimeError
10+
11+
12+
class GraphResolver:
13+
def __init__(self, entrypoint: Optional[str] = None) -> None:
14+
self.entrypoint = entrypoint
15+
self.graph_config: Optional[GraphConfig] = None
16+
17+
async def __call__(self) -> StateGraph[Any, Any, Any]:
18+
return await self._resolve(self.entrypoint)
19+
20+
async def _resolve(self, entrypoint: Optional[str]) -> StateGraph[Any, Any, Any]:
21+
config = LangGraphConfig()
22+
if not config.exists:
23+
raise LangGraphRuntimeError(
24+
"CONFIG_MISSING",
25+
"Invalid configuration",
26+
"Failed to load configuration",
27+
UiPathErrorCategory.DEPLOYMENT,
28+
)
29+
30+
try:
31+
config.load_config()
32+
except Exception as e:
33+
raise LangGraphRuntimeError(
34+
"CONFIG_INVALID",
35+
"Invalid configuration",
36+
f"Failed to load configuration: {str(e)}",
37+
UiPathErrorCategory.DEPLOYMENT,
38+
) from e
39+
40+
# Determine entrypoint if not provided
41+
graphs = config.graphs
42+
if not entrypoint and len(graphs) == 1:
43+
entrypoint = graphs[0].name
44+
elif not entrypoint:
45+
graph_names = ", ".join(g.name for g in graphs)
46+
raise LangGraphRuntimeError(
47+
"ENTRYPOINT_MISSING",
48+
"Entrypoint required",
49+
f"Multiple graphs available. Please specify one of: {graph_names}.",
50+
UiPathErrorCategory.DEPLOYMENT,
51+
)
52+
53+
# Get the specified graph
54+
self.graph_config = config.get_graph(entrypoint)
55+
if not self.graph_config:
56+
raise LangGraphRuntimeError(
57+
"GRAPH_NOT_FOUND",
58+
"Graph not found",
59+
f"Graph '{entrypoint}' not found.",
60+
UiPathErrorCategory.DEPLOYMENT,
61+
)
62+
try:
63+
loaded_graph = await self.graph_config.load_graph()
64+
return (
65+
loaded_graph.builder
66+
if isinstance(loaded_graph, CompiledStateGraph)
67+
else loaded_graph
68+
)
69+
except ImportError as e:
70+
raise LangGraphRuntimeError(
71+
"GRAPH_IMPORT_ERROR",
72+
"Graph import failed",
73+
f"Failed to import graph '{entrypoint}': {str(e)}",
74+
UiPathErrorCategory.USER,
75+
) from e
76+
except TypeError as e:
77+
raise LangGraphRuntimeError(
78+
"GRAPH_TYPE_ERROR",
79+
"Invalid graph type",
80+
f"Graph '{entrypoint}' is not a valid StateGraph or CompiledStateGraph: {str(e)}",
81+
UiPathErrorCategory.USER,
82+
) from e
83+
except ValueError as e:
84+
raise LangGraphRuntimeError(
85+
"GRAPH_VALUE_ERROR",
86+
"Invalid graph value",
87+
f"Invalid value in graph '{entrypoint}': {str(e)}",
88+
UiPathErrorCategory.USER,
89+
) from e
90+
except Exception as e:
91+
raise LangGraphRuntimeError(
92+
"GRAPH_LOAD_ERROR",
93+
"Failed to load graph",
94+
f"Unexpected error loading graph '{entrypoint}': {str(e)}",
95+
UiPathErrorCategory.USER,
96+
) from e
97+
98+
async def cleanup(self):
99+
"""Clean up resources"""
100+
if self.graph_config:
101+
await self.graph_config.cleanup()
102+
self.graph_config = None

src/uipath_langchain/_cli/_runtime/_runtime.py

Lines changed: 35 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,42 @@
11
import json
22
import logging
33
import os
4-
from typing import Any, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
55

66
from langchain_core.callbacks.base import BaseCallbackHandler
77
from langchain_core.messages import BaseMessage
88
from langchain_core.runnables.config import RunnableConfig
99
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
1010
from langgraph.errors import EmptyInputError, GraphRecursionError, InvalidUpdateError
11-
from langgraph.graph.state import CompiledStateGraph
11+
from langgraph.graph.state import CompiledStateGraph, StateGraph
1212
from uipath._cli._runtime._contracts import (
1313
UiPathBaseRuntime,
1414
UiPathErrorCategory,
1515
UiPathRuntimeResult,
1616
)
1717

18-
from .._utils._graph import LangGraphConfig
1918
from ._context import LangGraphRuntimeContext
2019
from ._conversation import map_message
2120
from ._exception import LangGraphRuntimeError
21+
from ._graph_resolver import GraphResolver
2222
from ._input import LangGraphInputProcessor
2323
from ._output import LangGraphOutputProcessor
2424

2525
logger = logging.getLogger(__name__)
2626

27+
AsyncResolver = Callable[[], Awaitable[StateGraph[Any, Any, Any]]]
28+
2729

2830
class LangGraphRuntime(UiPathBaseRuntime):
2931
"""
3032
A runtime class implementing the async context manager protocol.
3133
This allows using the class with 'async with' statements.
3234
"""
3335

34-
def __init__(self, context: LangGraphRuntimeContext):
36+
def __init__(self, context: LangGraphRuntimeContext, graph_resolver: AsyncResolver):
3537
super().__init__(context)
3638
self.context: LangGraphRuntimeContext = context
39+
self.graph_resolver: AsyncResolver = graph_resolver
3740

3841
async def execute(self) -> Optional[UiPathRuntimeResult]:
3942
"""
@@ -46,7 +49,8 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
4649
LangGraphRuntimeError: If execution fails
4750
"""
4851

49-
if self.context.state_graph is None:
52+
graph = await self.graph_resolver()
53+
if not graph:
5054
return None
5155

5256
try:
@@ -56,9 +60,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
5660
self.context.memory = memory
5761

5862
# Compile the graph with the checkpointer
59-
graph = self.context.state_graph.compile(
60-
checkpointer=self.context.memory
61-
)
63+
compiled_graph = graph.compile(checkpointer=self.context.memory)
6264

6365
# Process input, handling resume if needed
6466
input_processor = LangGraphInputProcessor(context=self.context)
@@ -87,7 +89,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
8789
graph_config["max_concurrency"] = int(max_concurrency)
8890

8991
if self.context.chat_handler:
90-
async for stream_chunk in graph.astream(
92+
async for stream_chunk in compiled_graph.astream(
9193
processed_input,
9294
graph_config,
9395
stream_mode="messages",
@@ -109,7 +111,7 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
109111
elif self.is_debug_run():
110112
# Get final chunk while streaming
111113
final_chunk = None
112-
async for stream_chunk in graph.astream(
114+
async for stream_chunk in compiled_graph.astream(
113115
processed_input,
114116
graph_config,
115117
stream_mode="updates",
@@ -118,16 +120,18 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
118120
self._pretty_print(stream_chunk)
119121
final_chunk = stream_chunk
120122

121-
self.context.output = self._extract_graph_result(final_chunk, graph)
123+
self.context.output = self._extract_graph_result(
124+
final_chunk, compiled_graph
125+
)
122126
else:
123127
# Execute the graph normally at runtime or eval
124-
self.context.output = await graph.ainvoke(
128+
self.context.output = await compiled_graph.ainvoke(
125129
processed_input, graph_config
126130
)
127131

128132
# Get the state if available
129133
try:
130-
self.context.state = await graph.aget_state(graph_config)
134+
self.context.state = await compiled_graph.aget_state(graph_config)
131135
except Exception:
132136
pass
133137

@@ -177,91 +181,10 @@ async def execute(self) -> Optional[UiPathRuntimeResult]:
177181
pass
178182

179183
async def validate(self) -> None:
180-
"""Validate runtime inputs."""
181-
"""Load and validate the graph configuration ."""
182-
if self.context.langgraph_config is None:
183-
self.context.langgraph_config = LangGraphConfig()
184-
if not self.context.langgraph_config.exists:
185-
raise LangGraphRuntimeError(
186-
"CONFIG_MISSING",
187-
"Invalid configuration",
188-
"Failed to load configuration",
189-
UiPathErrorCategory.DEPLOYMENT,
190-
)
191-
192-
try:
193-
self.context.langgraph_config.load_config()
194-
except Exception as e:
195-
raise LangGraphRuntimeError(
196-
"CONFIG_INVALID",
197-
"Invalid configuration",
198-
f"Failed to load configuration: {str(e)}",
199-
UiPathErrorCategory.DEPLOYMENT,
200-
) from e
201-
202-
# Determine entrypoint if not provided
203-
graphs = self.context.langgraph_config.graphs
204-
if not self.context.entrypoint and len(graphs) == 1:
205-
self.context.entrypoint = graphs[0].name
206-
elif not self.context.entrypoint:
207-
graph_names = ", ".join(g.name for g in graphs)
208-
raise LangGraphRuntimeError(
209-
"ENTRYPOINT_MISSING",
210-
"Entrypoint required",
211-
f"Multiple graphs available. Please specify one of: {graph_names}.",
212-
UiPathErrorCategory.DEPLOYMENT,
213-
)
214-
215-
# Get the specified graph
216-
self.graph_config = self.context.langgraph_config.get_graph(
217-
self.context.entrypoint
218-
)
219-
if not self.graph_config:
220-
raise LangGraphRuntimeError(
221-
"GRAPH_NOT_FOUND",
222-
"Graph not found",
223-
f"Graph '{self.context.entrypoint}' not found.",
224-
UiPathErrorCategory.DEPLOYMENT,
225-
)
226-
try:
227-
loaded_graph = await self.graph_config.load_graph()
228-
self.context.state_graph = (
229-
loaded_graph.builder
230-
if isinstance(loaded_graph, CompiledStateGraph)
231-
else loaded_graph
232-
)
233-
except ImportError as e:
234-
raise LangGraphRuntimeError(
235-
"GRAPH_IMPORT_ERROR",
236-
"Graph import failed",
237-
f"Failed to import graph '{self.context.entrypoint}': {str(e)}",
238-
UiPathErrorCategory.USER,
239-
) from e
240-
except TypeError as e:
241-
raise LangGraphRuntimeError(
242-
"GRAPH_TYPE_ERROR",
243-
"Invalid graph type",
244-
f"Graph '{self.context.entrypoint}' is not a valid StateGraph or CompiledStateGraph: {str(e)}",
245-
UiPathErrorCategory.USER,
246-
) from e
247-
except ValueError as e:
248-
raise LangGraphRuntimeError(
249-
"GRAPH_VALUE_ERROR",
250-
"Invalid graph value",
251-
f"Invalid value in graph '{self.context.entrypoint}': {str(e)}",
252-
UiPathErrorCategory.USER,
253-
) from e
254-
except Exception as e:
255-
raise LangGraphRuntimeError(
256-
"GRAPH_LOAD_ERROR",
257-
"Failed to load graph",
258-
f"Unexpected error loading graph '{self.context.entrypoint}': {str(e)}",
259-
UiPathErrorCategory.USER,
260-
) from e
184+
pass
261185

262186
async def cleanup(self):
263-
if hasattr(self, "graph_config") and self.graph_config:
264-
await self.graph_config.cleanup()
187+
pass
265188

266189
def _extract_graph_result(
267190
self, final_chunk, graph: CompiledStateGraph[Any, Any, Any]
@@ -377,3 +300,19 @@ def _pretty_print(self, stream_chunk: Union[Tuple[Any, Any], Dict[str, Any], Any
377300
logger.info("%s", formatted_metadata)
378301
except (TypeError, ValueError):
379302
pass
303+
304+
305+
class LangGraphScriptRuntime(LangGraphRuntime):
306+
"""
307+
Resolves the graph from langgraph.json config file and passes it to the base runtime.
308+
"""
309+
310+
def __init__(
311+
self, context: LangGraphRuntimeContext, entrypoint: Optional[str] = None
312+
):
313+
self.resolver = GraphResolver(entrypoint=entrypoint)
314+
super().__init__(context, self.resolver)
315+
316+
async def cleanup(self):
317+
await super().cleanup()
318+
await self.resolver.cleanup()

src/uipath_langchain/_cli/cli_dev.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .._tracing import _instrument_traceable_attributes
1414
from ._runtime._context import LangGraphRuntimeContext
15-
from ._runtime._runtime import LangGraphRuntime
15+
from ._runtime._runtime import LangGraphScriptRuntime
1616

1717
console = ConsoleLogger()
1818

@@ -22,8 +22,14 @@ def langgraph_dev_middleware(interface: Optional[str]) -> MiddlewareResult:
2222

2323
try:
2424
if interface == "terminal":
25+
26+
def generate_runtime(
27+
ctx: LangGraphRuntimeContext,
28+
) -> LangGraphScriptRuntime:
29+
return LangGraphScriptRuntime(ctx, ctx.entrypoint)
30+
2531
runtime_factory = UiPathRuntimeFactory(
26-
LangGraphRuntime, LangGraphRuntimeContext
32+
LangGraphScriptRuntime, LangGraphRuntimeContext, generate_runtime
2733
)
2834

2935
_instrument_traceable_attributes()

0 commit comments

Comments
 (0)