Skip to content

Commit 980f781

Browse files
mfateevjssmith
authored andcommitted
Added support for LangSmith tracing across workflows and activities (#188)
* Added langchain tracing interceptor. Changed sample to use a child workflow and multiple activities.
1 parent e38917c commit 980f781

File tree

8 files changed

+249
-17
lines changed

8 files changed

+249
-17
lines changed

langchain/README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# LangChain Sample
22

3-
This sample shows you how you can use Temporal to orchestrate workflows for [LangChain](https://www.langchain.com).
3+
This sample shows you how you can use Temporal to orchestrate workflows for [LangChain](https://www.langchain.com). It includes an interceptor that makes LangSmith traces work seamlessly across Temporal clients, workflows and activities.
44

55
For this sample, the optional `langchain` dependency group must be included. To include, run:
66

@@ -21,8 +21,10 @@ This will start the worker. Then, in another terminal, run the following to exec
2121

2222
Then, in another terminal, run the following command to translate a phrase:
2323

24-
curl -X POST "http://localhost:8000/translate?phrase=hello%20world&language=Spanish"
24+
curl -X POST "http://localhost:8000/translate?phrase=hello%20world&language1=Spanish&language2=French&language3=Russian"
2525

2626
Which should produce some output like:
2727

28-
{"translation":"Hola mundo"}
28+
{"translations":{"French":"Bonjour tout le monde","Russian":"Привет, мир","Spanish":"Hola mundo"}}
29+
30+
Check [LangSmith](https://smith.langchain.com/) for the corresponding trace.

langchain/activities.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TranslateParams:
1313

1414

1515
@activity.defn
16-
async def translate_phrase(params: TranslateParams) -> dict:
16+
async def translate_phrase(params: TranslateParams) -> str:
1717
# LangChain setup
1818
template = """You are a helpful assistant who translates between languages.
1919
Translate the following phrase into the specified language: {phrase}
@@ -26,6 +26,9 @@ async def translate_phrase(params: TranslateParams) -> dict:
2626
)
2727
chain = chat_prompt | ChatOpenAI()
2828
# Use the asynchronous invoke method
29-
return dict(
30-
await chain.ainvoke({"phrase": params.phrase, "language": params.language})
29+
return (
30+
dict(
31+
await chain.ainvoke({"phrase": params.phrase, "language": params.language})
32+
).get("content")
33+
or ""
3134
)

langchain/langchain_interceptor.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Mapping, Protocol, Type
4+
5+
from temporalio import activity, api, client, converter, worker, workflow
6+
7+
with workflow.unsafe.imports_passed_through():
8+
from contextlib import contextmanager
9+
10+
from langsmith import trace, tracing_context
11+
from langsmith.run_helpers import get_current_run_tree
12+
13+
# Header key for LangChain context
14+
LANGCHAIN_CONTEXT_KEY = "langchain-context"
15+
16+
17+
class _InputWithHeaders(Protocol):
18+
headers: Mapping[str, api.common.v1.Payload]
19+
20+
21+
def set_header_from_context(
22+
input: _InputWithHeaders, payload_converter: converter.PayloadConverter
23+
) -> None:
24+
# Get current LangChain run tree
25+
run_tree = get_current_run_tree()
26+
if run_tree:
27+
headers = run_tree.to_headers()
28+
input.headers = {
29+
**input.headers,
30+
LANGCHAIN_CONTEXT_KEY: payload_converter.to_payload(headers),
31+
}
32+
33+
34+
@contextmanager
35+
def context_from_header(
36+
input: _InputWithHeaders, payload_converter: converter.PayloadConverter
37+
):
38+
payload = input.headers.get(LANGCHAIN_CONTEXT_KEY)
39+
if payload:
40+
run_tree = payload_converter.from_payload(payload, dict)
41+
# Set the run tree in the current context
42+
with tracing_context(parent=run_tree):
43+
yield
44+
else:
45+
yield
46+
47+
48+
class LangChainContextPropagationInterceptor(client.Interceptor, worker.Interceptor):
49+
"""Interceptor that propagates LangChain context through Temporal."""
50+
51+
def __init__(
52+
self,
53+
payload_converter: converter.PayloadConverter = converter.default().payload_converter,
54+
) -> None:
55+
self._payload_converter = payload_converter
56+
57+
def intercept_client(
58+
self, next: client.OutboundInterceptor
59+
) -> client.OutboundInterceptor:
60+
return _LangChainContextPropagationClientOutboundInterceptor(
61+
next, self._payload_converter
62+
)
63+
64+
def intercept_activity(
65+
self, next: worker.ActivityInboundInterceptor
66+
) -> worker.ActivityInboundInterceptor:
67+
return _LangChainContextPropagationActivityInboundInterceptor(next)
68+
69+
def workflow_interceptor_class(
70+
self, input: worker.WorkflowInterceptorClassInput
71+
) -> Type[_LangChainContextPropagationWorkflowInboundInterceptor]:
72+
return _LangChainContextPropagationWorkflowInboundInterceptor
73+
74+
75+
class _LangChainContextPropagationClientOutboundInterceptor(client.OutboundInterceptor):
76+
def __init__(
77+
self,
78+
next: client.OutboundInterceptor,
79+
payload_converter: converter.PayloadConverter,
80+
) -> None:
81+
super().__init__(next)
82+
self._payload_converter = payload_converter
83+
84+
async def start_workflow(
85+
self, input: client.StartWorkflowInput
86+
) -> client.WorkflowHandle[Any, Any]:
87+
with trace(name=f"start_workflow:{input.workflow}"):
88+
set_header_from_context(input, self._payload_converter)
89+
return await super().start_workflow(input)
90+
91+
92+
class _LangChainContextPropagationActivityInboundInterceptor(
93+
worker.ActivityInboundInterceptor
94+
):
95+
async def execute_activity(self, input: worker.ExecuteActivityInput) -> Any:
96+
if isinstance(input.fn, str):
97+
name = input.fn
98+
elif callable(input.fn):
99+
defn = activity._Definition.from_callable(input.fn)
100+
name = (
101+
defn.name if defn is not None and defn.name is not None else "unknown"
102+
)
103+
else:
104+
name = "unknown"
105+
106+
with context_from_header(input, activity.payload_converter()):
107+
with trace(name=f"execute_activity:{name}"):
108+
return await self.next.execute_activity(input)
109+
110+
111+
class _LangChainContextPropagationWorkflowInboundInterceptor(
112+
worker.WorkflowInboundInterceptor
113+
):
114+
def init(self, outbound: worker.WorkflowOutboundInterceptor) -> None:
115+
self.next.init(
116+
_LangChainContextPropagationWorkflowOutboundInterceptor(outbound)
117+
)
118+
119+
async def execute_workflow(self, input: worker.ExecuteWorkflowInput) -> Any:
120+
if isinstance(input.run_fn, str):
121+
name = input.run_fn
122+
elif callable(input.run_fn):
123+
defn = workflow._Definition.from_run_fn(input.run_fn)
124+
name = (
125+
defn.name if defn is not None and defn.name is not None else "unknown"
126+
)
127+
else:
128+
name = "unknown"
129+
130+
with context_from_header(input, workflow.payload_converter()):
131+
# This is a sandbox friendly way to write
132+
# with trace(...):
133+
# return await self.next.execute_workflow(input)
134+
with workflow.unsafe.sandbox_unrestricted():
135+
t = trace(
136+
name=f"execute_workflow:{name}", run_id=workflow.info().run_id
137+
)
138+
with workflow.unsafe.imports_passed_through():
139+
t.__enter__()
140+
try:
141+
return await self.next.execute_workflow(input)
142+
finally:
143+
with workflow.unsafe.sandbox_unrestricted():
144+
# Cannot use __aexit__ because it's internally uses
145+
# loop.run_in_executor which is not available in the sandbox
146+
t.__exit__()
147+
148+
149+
class _LangChainContextPropagationWorkflowOutboundInterceptor(
150+
worker.WorkflowOutboundInterceptor
151+
):
152+
def start_activity(
153+
self, input: worker.StartActivityInput
154+
) -> workflow.ActivityHandle:
155+
with workflow.unsafe.sandbox_unrestricted():
156+
t = trace(name=f"start_activity:{input.activity}", run_id=workflow.uuid4())
157+
with workflow.unsafe.imports_passed_through():
158+
t.__enter__()
159+
try:
160+
set_header_from_context(input, workflow.payload_converter())
161+
return self.next.start_activity(input)
162+
finally:
163+
with workflow.unsafe.sandbox_unrestricted():
164+
t.__exit__()
165+
166+
async def start_child_workflow(
167+
self, input: worker.StartChildWorkflowInput
168+
) -> workflow.ChildWorkflowHandle:
169+
with workflow.unsafe.sandbox_unrestricted():
170+
t = trace(
171+
name=f"start_child_workflow:{input.workflow}", run_id=workflow.uuid4()
172+
)
173+
with workflow.unsafe.imports_passed_through():
174+
t.__enter__()
175+
176+
try:
177+
set_header_from_context(input, workflow.payload_converter())
178+
return await self.next.start_child_workflow(input)
179+
finally:
180+
with workflow.unsafe.sandbox_unrestricted():
181+
t.__exit__()

langchain/starter.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,41 @@
11
from contextlib import asynccontextmanager
2+
from typing import List
23
from uuid import uuid4
34

45
import uvicorn
56
from activities import TranslateParams
67
from fastapi import FastAPI, HTTPException
8+
from langchain_interceptor import LangChainContextPropagationInterceptor
79
from temporalio.client import Client
8-
from workflow import LangChainWorkflow
10+
from workflow import LangChainWorkflow, TranslateWorkflowParams
911

1012

1113
@asynccontextmanager
1214
async def lifespan(app: FastAPI):
13-
app.state.temporal_client = await Client.connect("localhost:7233")
15+
app.state.temporal_client = await Client.connect(
16+
"localhost:7233", interceptors=[LangChainContextPropagationInterceptor()]
17+
)
1418
yield
1519

1620

1721
app = FastAPI(lifespan=lifespan)
1822

1923

2024
@app.post("/translate")
21-
async def translate(phrase: str, language: str):
25+
async def translate(phrase: str, language1: str, language2: str, language3: str):
26+
languages = [language1, language2, language3]
2227
client = app.state.temporal_client
2328
try:
2429
result = await client.execute_workflow(
2530
LangChainWorkflow.run,
26-
TranslateParams(phrase, language),
31+
TranslateWorkflowParams(phrase, languages),
2732
id=f"langchain-translation-{uuid4()}",
2833
task_queue="langchain-task-queue",
2934
)
30-
translation_content = result.get("content", "Translation not available")
3135
except Exception as e:
3236
raise HTTPException(status_code=500, detail=str(e))
3337

34-
return {"translation": translation_content}
38+
return {"translations": result}
3539

3640

3741
if __name__ == "__main__":

langchain/worker.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
22

33
from activities import translate_phrase
4+
from langchain_interceptor import LangChainContextPropagationInterceptor
45
from temporalio.client import Client
56
from temporalio.worker import Worker
6-
from workflow import LangChainWorkflow
7+
from workflow import LangChainChildWorkflow, LangChainWorkflow
78

89
interrupt_event = asyncio.Event()
910

@@ -13,8 +14,9 @@ async def main():
1314
worker = Worker(
1415
client,
1516
task_queue="langchain-task-queue",
16-
workflows=[LangChainWorkflow],
17+
workflows=[LangChainWorkflow, LangChainChildWorkflow],
1718
activities=[translate_phrase],
19+
interceptors=[LangChainContextPropagationInterceptor()],
1820
)
1921

2022
print("\nWorker started, ctrl+c to exit\n")
@@ -28,7 +30,8 @@ async def main():
2830

2931

3032
if __name__ == "__main__":
31-
loop = asyncio.get_event_loop()
33+
loop = asyncio.new_event_loop()
34+
asyncio.set_event_loop(loop)
3235
try:
3336
loop.run_until_complete(main())
3437
except KeyboardInterrupt:

langchain/workflow.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import asyncio
2+
from dataclasses import dataclass
13
from datetime import timedelta
4+
from typing import List
25

36
from temporalio import workflow
47

@@ -7,11 +10,44 @@
710

811

912
@workflow.defn
10-
class LangChainWorkflow:
13+
class LangChainChildWorkflow:
1114
@workflow.run
12-
async def run(self, params: TranslateParams) -> dict:
15+
async def run(self, params: TranslateParams) -> str:
1316
return await workflow.execute_activity(
1417
translate_phrase,
1518
params,
1619
schedule_to_close_timeout=timedelta(seconds=30),
1720
)
21+
22+
23+
@dataclass
24+
class TranslateWorkflowParams:
25+
phrase: str
26+
languages: List[str]
27+
28+
29+
@workflow.defn
30+
class LangChainWorkflow:
31+
@workflow.run
32+
async def run(self, params: TranslateWorkflowParams) -> dict:
33+
result1, result2, result3 = await asyncio.gather(
34+
workflow.execute_activity(
35+
translate_phrase,
36+
TranslateParams(params.phrase, params.languages[0]),
37+
schedule_to_close_timeout=timedelta(seconds=30),
38+
),
39+
workflow.execute_activity(
40+
translate_phrase,
41+
TranslateParams(params.phrase, params.languages[1]),
42+
schedule_to_close_timeout=timedelta(seconds=30),
43+
),
44+
workflow.execute_child_workflow(
45+
LangChainChildWorkflow.run,
46+
TranslateParams(params.phrase, params.languages[2]),
47+
),
48+
)
49+
return {
50+
params.languages[0]: result1,
51+
params.languages[1]: result2,
52+
params.languages[2]: result3,
53+
}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ gevent = ["gevent==25.4.2 ; python_version >= '3.8'"]
3838
langchain = [
3939
"langchain>=0.1.7,<0.2 ; python_version >= '3.8.1' and python_version < '4.0'",
4040
"langchain-openai>=0.0.6,<0.0.7 ; python_version >= '3.8.1' and python_version < '4.0'",
41+
"langsmith>=0.1.22,<0.2 ; python_version >= '3.8.1' and python_version < '4.0'",
4142
"openai>=1.4.0,<2",
4243
"fastapi>=0.105.0,<0.106",
4344
"tqdm>=4.62.0,<5",

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)