Skip to content

Commit e5dad4a

Browse files
fix(agent): propagate ContextVars to concurrent tool worker threads (#18123)
Propagates ContextVars (notably `tools.approval._approval_session_key`) into concurrent tool worker threads via `copy_context().run` — mirrors `asyncio.to_thread` semantics. Fixes approval-card cross-session misrouting in concurrent gateway traffic. Repro'd on Slack: session A's dangerous-command approval was delivered to channel B (@syahidfrd). Salvages #16660 — core 4-LOC fix preserved, unrelated `tests/eval_018/` scope contamination dropped. Adds 5 regression guards including an AST-level source check on the real call site. Closes #16660. Co-authored-by: firefly <promptsiren@gmail.com> Co-authored-by: banditburai <banditburai@users.noreply.github.com>
1 parent 180a703 commit e5dad4a

2 files changed

Lines changed: 253 additions & 1 deletion

File tree

run_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import asyncio
2424
import base64
2525
import concurrent.futures
26+
import contextvars
2627
import copy
2728
import hashlib
2829
import json
@@ -9443,7 +9444,9 @@ def _run_tool(index, tool_call, function_name, function_args):
94439444
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
94449445
futures = []
94459446
for i, (tc, name, args) in enumerate(parsed_calls):
9446-
f = executor.submit(_run_tool, i, tc, name, args)
9447+
# Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread.
9448+
ctx = contextvars.copy_context()
9449+
f = executor.submit(ctx.run, _run_tool, i, tc, name, args)
94479450
futures.append(f)
94489451

94499452
# Wait for all to complete with periodic heartbeats so the
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
"""Regression guard for PR #16660 (salvaged as PR #18027): ContextVar
2+
propagation into concurrent tool worker threads.
3+
4+
Background
5+
----------
6+
Gateway adapters (Slack, Telegram, Discord, ...) set
7+
``tools.approval._approval_session_key`` as a ContextVar before calling
8+
``agent.run_conversation`` so that dangerous-command approval prompts route
9+
back to the channel/session that initiated the tool call. When the agent
10+
dispatches multiple tools in parallel, it uses
11+
``concurrent.futures.ThreadPoolExecutor.submit(...)`` — and ``submit`` runs
12+
the callable in a *fresh* context, NOT the caller's context. Without an
13+
explicit ``contextvars.copy_context().run(...)`` wrapper, worker threads
14+
observe the ContextVar's default value, fall through to the
15+
``os.environ`` legacy fallback (which the gateway overwrites at each
16+
agent step), and route the approval card to *whichever session stepped
17+
most recently* — not the one that raised the prompt. Confirmed in the
18+
wild on Slack with two concurrent channels: session A's `rm -rf`
19+
approval card was delivered to session B.
20+
21+
The fix (4 LOC in ``run_agent.py``) snapshots the caller's context with
22+
``copy_context()`` and submits ``ctx.run(_run_tool, …)`` instead of
23+
``_run_tool`` directly. Mirrors ``asyncio.to_thread`` semantics.
24+
25+
This suite follows the ``contextvar-run-in-executor-bridge`` skill's
26+
two-test pattern: one end-to-end test proves the fix works at the
27+
call-site level, one documents the Python contract that makes the fix
28+
necessary. If anyone ever reverts the wrapper, the call-site test
29+
fails while the contract test keeps passing — a clear diagnostic
30+
signal for *why* the call-site regressed.
31+
"""
32+
33+
from __future__ import annotations
34+
35+
import concurrent.futures
36+
import contextvars
37+
import threading
38+
39+
40+
def test_executor_submit_without_copy_context_does_not_propagate():
41+
"""Documents the Python contract the fix relies on.
42+
43+
``concurrent.futures.ThreadPoolExecutor.submit(fn)`` runs ``fn`` in a
44+
worker thread with a fresh, empty context. A ContextVar set by the
45+
caller is invisible inside ``fn``. This is the exact trap that made
46+
approval-session routing race in the gateway before #16660.
47+
48+
If this test ever fails — i.e. submit() starts propagating
49+
ContextVars by default — the copy_context() wrapper in run_agent.py
50+
becomes redundant but not harmful, and the call-site test below
51+
should be updated accordingly.
52+
"""
53+
probe: contextvars.ContextVar[str] = contextvars.ContextVar(
54+
"probe_default_propagation", default="unset"
55+
)
56+
57+
def read_in_worker() -> str:
58+
return probe.get()
59+
60+
probe.set("set-in-main")
61+
62+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
63+
observed = ex.submit(read_in_worker).result(timeout=5)
64+
65+
assert observed == "unset", (
66+
"Unexpected: executor.submit propagated a ContextVar without "
67+
"copy_context(). If Python's behavior changed, update "
68+
"test_run_tool_worker_sees_parent_context below."
69+
)
70+
71+
72+
def test_executor_submit_with_copy_context_run_propagates():
73+
"""Positive case: the explicit ``copy_context().run(...)`` wrapper the
74+
PR adds makes parent-context ContextVar values visible in the worker.
75+
"""
76+
probe: contextvars.ContextVar[str] = contextvars.ContextVar(
77+
"probe_explicit_propagation", default="unset"
78+
)
79+
80+
def read_in_worker() -> str:
81+
return probe.get()
82+
83+
probe.set("set-in-main")
84+
85+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
86+
ctx = contextvars.copy_context()
87+
observed = ex.submit(ctx.run, read_in_worker).result(timeout=5)
88+
89+
assert observed == "set-in-main", (
90+
f"copy_context().run(...) failed to propagate: got {observed!r}"
91+
)
92+
93+
94+
def test_run_tool_worker_sees_parent_approval_session_key():
95+
"""End-to-end call-site guard.
96+
97+
Mirrors the exact shape of the fixed call site in
98+
``run_agent.py::_execute_tool_calls_concurrent`` — a
99+
``ThreadPoolExecutor`` with ``executor.submit(ctx.run, fn, *args)``.
100+
Sets the real ``tools.approval._approval_session_key`` ContextVar
101+
in the caller and asserts the worker observes it via
102+
``tools.approval.get_current_session_key()``.
103+
104+
If the PR's ``copy_context().run`` wrapper is reverted, this test
105+
fails with ``Expected 'session-A' but worker saw 'default'``.
106+
"""
107+
from tools.approval import (
108+
_approval_session_key,
109+
get_current_session_key,
110+
)
111+
112+
observed: dict = {}
113+
barrier = threading.Event()
114+
115+
def worker_equivalent_to_run_tool() -> None:
116+
# Mirror what real _run_tool does early: read the session key.
117+
observed["session_key"] = get_current_session_key(default="FALLBACK")
118+
barrier.set()
119+
120+
# Set the ContextVar the gateway would set before calling agent.run.
121+
token = _approval_session_key.set("session-A")
122+
try:
123+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
124+
ctx = contextvars.copy_context()
125+
fut = ex.submit(ctx.run, worker_equivalent_to_run_tool)
126+
fut.result(timeout=5)
127+
assert barrier.is_set(), "worker did not complete"
128+
finally:
129+
_approval_session_key.reset(token)
130+
131+
assert observed.get("session_key") == "session-A", (
132+
f"Worker thread did not inherit _approval_session_key from caller. "
133+
f"Expected 'session-A', got {observed.get('session_key')!r}. "
134+
"This is the bug that PR #16660 fixed — approval prompts route to "
135+
"the wrong session in concurrent gateway traffic. Check whether "
136+
"the copy_context().run wrapper in _execute_tool_calls_concurrent "
137+
"was removed."
138+
)
139+
140+
141+
def test_run_agent_concurrent_executor_wraps_submit_with_copy_context():
142+
"""Source-level guard that the fix stays at the REAL call site.
143+
144+
The behavioral tests above exercise the pattern in isolation and
145+
pass regardless of whether ``run_agent.py`` actually uses it.
146+
This guard inspects ``_execute_tool_calls_concurrent`` directly and
147+
asserts that ``executor.submit`` is called with ``ctx.run`` (or
148+
``copy_context()`` appears within a few lines) — so reverting the
149+
wrapper in ``run_agent.py`` fails this test with a clear message.
150+
"""
151+
import ast
152+
import inspect
153+
154+
import run_agent
155+
156+
src_path = inspect.getsourcefile(run_agent)
157+
assert src_path is not None
158+
tree = ast.parse(open(src_path, encoding="utf-8").read())
159+
160+
submit_calls_in_agent: list[ast.Call] = []
161+
for node in ast.walk(tree):
162+
if not isinstance(node, ast.Call):
163+
continue
164+
func = node.func
165+
# Match executor.submit(...) style calls.
166+
if isinstance(func, ast.Attribute) and func.attr == "submit":
167+
submit_calls_in_agent.append(node)
168+
169+
# Filter to the submit call inside the concurrent tool executor —
170+
# identifiable by passing `_run_tool` as its target. Other submit()
171+
# call sites in run_agent.py (e.g. auxiliary client warm-up) are
172+
# out of scope for this regression.
173+
tool_submits = []
174+
for call in submit_calls_in_agent:
175+
if not call.args:
176+
continue
177+
first = call.args[0]
178+
# Unfixed: executor.submit(_run_tool, ...) → first arg is a Name
179+
if isinstance(first, ast.Name) and first.id == "_run_tool":
180+
tool_submits.append(("unfixed", call))
181+
# Fixed: executor.submit(ctx.run, _run_tool, ...) → first arg is
182+
# ctx.run (Attribute), and _run_tool is the second arg.
183+
elif (
184+
isinstance(first, ast.Attribute)
185+
and first.attr == "run"
186+
and len(call.args) >= 2
187+
and isinstance(call.args[1], ast.Name)
188+
and call.args[1].id == "_run_tool"
189+
):
190+
tool_submits.append(("fixed", call))
191+
192+
assert tool_submits, (
193+
"Could not locate `executor.submit(... _run_tool ...)` in "
194+
"run_agent.py. The call site may have been renamed — update this "
195+
"guard along with the refactor."
196+
)
197+
unfixed = [c for kind, c in tool_submits if kind == "unfixed"]
198+
assert not unfixed, (
199+
"run_agent.py contains `executor.submit(_run_tool, ...)` without a "
200+
"`ctx.run` wrapper. This is the pre-#16660 shape: worker threads "
201+
"will read a fresh ContextVar and approval-session routing "
202+
"collapses to the os.environ fallback. Wrap with "
203+
"`ctx = contextvars.copy_context(); executor.submit(ctx.run, "
204+
"_run_tool, ...)`."
205+
)
206+
207+
208+
def test_two_concurrent_tool_batches_keep_session_keys_isolated():
209+
"""End-to-end guard: two callers each set a different session key
210+
and submit workers concurrently. Each worker must see its own
211+
caller's key, not the other's.
212+
213+
Guards against a future "optimization" that reuses a single context
214+
snapshot across callers (which would collapse isolation the same way
215+
the unfixed ``submit`` does).
216+
"""
217+
from tools.approval import (
218+
_approval_session_key,
219+
get_current_session_key,
220+
)
221+
222+
results: dict = {}
223+
224+
def caller(label: str) -> None:
225+
token = _approval_session_key.set(f"session-{label}")
226+
try:
227+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
228+
ctx = contextvars.copy_context()
229+
fut = ex.submit(
230+
ctx.run,
231+
lambda: get_current_session_key(default="FALLBACK"),
232+
)
233+
results[label] = fut.result(timeout=5)
234+
finally:
235+
_approval_session_key.reset(token)
236+
237+
t_a = threading.Thread(target=caller, args=("A",))
238+
t_b = threading.Thread(target=caller, args=("B",))
239+
t_a.start()
240+
t_b.start()
241+
t_a.join(timeout=10)
242+
t_b.join(timeout=10)
243+
244+
assert results.get("A") == "session-A", (
245+
f"Session A worker saw {results.get('A')!r}, expected 'session-A'"
246+
)
247+
assert results.get("B") == "session-B", (
248+
f"Session B worker saw {results.get('B')!r}, expected 'session-B'"
249+
)

0 commit comments

Comments
 (0)