|
| 1 | +"""Regression guard for PR #16660 (salvaged as PR #18027): ContextVar |
| 2 | +propagation into concurrent tool worker threads. |
| 3 | +
|
| 4 | +Background |
| 5 | +---------- |
| 6 | +Gateway adapters (Slack, Telegram, Discord, ...) set |
| 7 | +``tools.approval._approval_session_key`` as a ContextVar before calling |
| 8 | +``agent.run_conversation`` so that dangerous-command approval prompts route |
| 9 | +back to the channel/session that initiated the tool call. When the agent |
| 10 | +dispatches multiple tools in parallel, it uses |
| 11 | +``concurrent.futures.ThreadPoolExecutor.submit(...)`` — and ``submit`` runs |
| 12 | +the callable in a *fresh* context, NOT the caller's context. Without an |
| 13 | +explicit ``contextvars.copy_context().run(...)`` wrapper, worker threads |
| 14 | +observe the ContextVar's default value, fall through to the |
| 15 | +``os.environ`` legacy fallback (which the gateway overwrites at each |
| 16 | +agent step), and route the approval card to *whichever session stepped |
| 17 | +most recently* — not the one that raised the prompt. Confirmed in the |
| 18 | +wild on Slack with two concurrent channels: session A's `rm -rf` |
| 19 | +approval card was delivered to session B. |
| 20 | +
|
| 21 | +The fix (4 LOC in ``run_agent.py``) snapshots the caller's context with |
| 22 | +``copy_context()`` and submits ``ctx.run(_run_tool, …)`` instead of |
| 23 | +``_run_tool`` directly. Mirrors ``asyncio.to_thread`` semantics. |
| 24 | +
|
| 25 | +This suite follows the ``contextvar-run-in-executor-bridge`` skill's |
| 26 | +two-test pattern: one end-to-end test proves the fix works at the |
| 27 | +call-site level, one documents the Python contract that makes the fix |
| 28 | +necessary. If anyone ever reverts the wrapper, the call-site test |
| 29 | +fails while the contract test keeps passing — a clear diagnostic |
| 30 | +signal for *why* the call-site regressed. |
| 31 | +""" |
| 32 | + |
| 33 | +from __future__ import annotations |
| 34 | + |
| 35 | +import concurrent.futures |
| 36 | +import contextvars |
| 37 | +import threading |
| 38 | + |
| 39 | + |
| 40 | +def test_executor_submit_without_copy_context_does_not_propagate(): |
| 41 | + """Documents the Python contract the fix relies on. |
| 42 | +
|
| 43 | + ``concurrent.futures.ThreadPoolExecutor.submit(fn)`` runs ``fn`` in a |
| 44 | + worker thread with a fresh, empty context. A ContextVar set by the |
| 45 | + caller is invisible inside ``fn``. This is the exact trap that made |
| 46 | + approval-session routing race in the gateway before #16660. |
| 47 | +
|
| 48 | + If this test ever fails — i.e. submit() starts propagating |
| 49 | + ContextVars by default — the copy_context() wrapper in run_agent.py |
| 50 | + becomes redundant but not harmful, and the call-site test below |
| 51 | + should be updated accordingly. |
| 52 | + """ |
| 53 | + probe: contextvars.ContextVar[str] = contextvars.ContextVar( |
| 54 | + "probe_default_propagation", default="unset" |
| 55 | + ) |
| 56 | + |
| 57 | + def read_in_worker() -> str: |
| 58 | + return probe.get() |
| 59 | + |
| 60 | + probe.set("set-in-main") |
| 61 | + |
| 62 | + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex: |
| 63 | + observed = ex.submit(read_in_worker).result(timeout=5) |
| 64 | + |
| 65 | + assert observed == "unset", ( |
| 66 | + "Unexpected: executor.submit propagated a ContextVar without " |
| 67 | + "copy_context(). If Python's behavior changed, update " |
| 68 | + "test_run_tool_worker_sees_parent_context below." |
| 69 | + ) |
| 70 | + |
| 71 | + |
| 72 | +def test_executor_submit_with_copy_context_run_propagates(): |
| 73 | + """Positive case: the explicit ``copy_context().run(...)`` wrapper the |
| 74 | + PR adds makes parent-context ContextVar values visible in the worker. |
| 75 | + """ |
| 76 | + probe: contextvars.ContextVar[str] = contextvars.ContextVar( |
| 77 | + "probe_explicit_propagation", default="unset" |
| 78 | + ) |
| 79 | + |
| 80 | + def read_in_worker() -> str: |
| 81 | + return probe.get() |
| 82 | + |
| 83 | + probe.set("set-in-main") |
| 84 | + |
| 85 | + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex: |
| 86 | + ctx = contextvars.copy_context() |
| 87 | + observed = ex.submit(ctx.run, read_in_worker).result(timeout=5) |
| 88 | + |
| 89 | + assert observed == "set-in-main", ( |
| 90 | + f"copy_context().run(...) failed to propagate: got {observed!r}" |
| 91 | + ) |
| 92 | + |
| 93 | + |
| 94 | +def test_run_tool_worker_sees_parent_approval_session_key(): |
| 95 | + """End-to-end call-site guard. |
| 96 | +
|
| 97 | + Mirrors the exact shape of the fixed call site in |
| 98 | + ``run_agent.py::_execute_tool_calls_concurrent`` — a |
| 99 | + ``ThreadPoolExecutor`` with ``executor.submit(ctx.run, fn, *args)``. |
| 100 | + Sets the real ``tools.approval._approval_session_key`` ContextVar |
| 101 | + in the caller and asserts the worker observes it via |
| 102 | + ``tools.approval.get_current_session_key()``. |
| 103 | +
|
| 104 | + If the PR's ``copy_context().run`` wrapper is reverted, this test |
| 105 | + fails with ``Expected 'session-A' but worker saw 'default'``. |
| 106 | + """ |
| 107 | + from tools.approval import ( |
| 108 | + _approval_session_key, |
| 109 | + get_current_session_key, |
| 110 | + ) |
| 111 | + |
| 112 | + observed: dict = {} |
| 113 | + barrier = threading.Event() |
| 114 | + |
| 115 | + def worker_equivalent_to_run_tool() -> None: |
| 116 | + # Mirror what real _run_tool does early: read the session key. |
| 117 | + observed["session_key"] = get_current_session_key(default="FALLBACK") |
| 118 | + barrier.set() |
| 119 | + |
| 120 | + # Set the ContextVar the gateway would set before calling agent.run. |
| 121 | + token = _approval_session_key.set("session-A") |
| 122 | + try: |
| 123 | + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex: |
| 124 | + ctx = contextvars.copy_context() |
| 125 | + fut = ex.submit(ctx.run, worker_equivalent_to_run_tool) |
| 126 | + fut.result(timeout=5) |
| 127 | + assert barrier.is_set(), "worker did not complete" |
| 128 | + finally: |
| 129 | + _approval_session_key.reset(token) |
| 130 | + |
| 131 | + assert observed.get("session_key") == "session-A", ( |
| 132 | + f"Worker thread did not inherit _approval_session_key from caller. " |
| 133 | + f"Expected 'session-A', got {observed.get('session_key')!r}. " |
| 134 | + "This is the bug that PR #16660 fixed — approval prompts route to " |
| 135 | + "the wrong session in concurrent gateway traffic. Check whether " |
| 136 | + "the copy_context().run wrapper in _execute_tool_calls_concurrent " |
| 137 | + "was removed." |
| 138 | + ) |
| 139 | + |
| 140 | + |
| 141 | +def test_run_agent_concurrent_executor_wraps_submit_with_copy_context(): |
| 142 | + """Source-level guard that the fix stays at the REAL call site. |
| 143 | +
|
| 144 | + The behavioral tests above exercise the pattern in isolation and |
| 145 | + pass regardless of whether ``run_agent.py`` actually uses it. |
| 146 | + This guard inspects ``_execute_tool_calls_concurrent`` directly and |
| 147 | + asserts that ``executor.submit`` is called with ``ctx.run`` (or |
| 148 | + ``copy_context()`` appears within a few lines) — so reverting the |
| 149 | + wrapper in ``run_agent.py`` fails this test with a clear message. |
| 150 | + """ |
| 151 | + import ast |
| 152 | + import inspect |
| 153 | + |
| 154 | + import run_agent |
| 155 | + |
| 156 | + src_path = inspect.getsourcefile(run_agent) |
| 157 | + assert src_path is not None |
| 158 | + tree = ast.parse(open(src_path, encoding="utf-8").read()) |
| 159 | + |
| 160 | + submit_calls_in_agent: list[ast.Call] = [] |
| 161 | + for node in ast.walk(tree): |
| 162 | + if not isinstance(node, ast.Call): |
| 163 | + continue |
| 164 | + func = node.func |
| 165 | + # Match executor.submit(...) style calls. |
| 166 | + if isinstance(func, ast.Attribute) and func.attr == "submit": |
| 167 | + submit_calls_in_agent.append(node) |
| 168 | + |
| 169 | + # Filter to the submit call inside the concurrent tool executor — |
| 170 | + # identifiable by passing `_run_tool` as its target. Other submit() |
| 171 | + # call sites in run_agent.py (e.g. auxiliary client warm-up) are |
| 172 | + # out of scope for this regression. |
| 173 | + tool_submits = [] |
| 174 | + for call in submit_calls_in_agent: |
| 175 | + if not call.args: |
| 176 | + continue |
| 177 | + first = call.args[0] |
| 178 | + # Unfixed: executor.submit(_run_tool, ...) → first arg is a Name |
| 179 | + if isinstance(first, ast.Name) and first.id == "_run_tool": |
| 180 | + tool_submits.append(("unfixed", call)) |
| 181 | + # Fixed: executor.submit(ctx.run, _run_tool, ...) → first arg is |
| 182 | + # ctx.run (Attribute), and _run_tool is the second arg. |
| 183 | + elif ( |
| 184 | + isinstance(first, ast.Attribute) |
| 185 | + and first.attr == "run" |
| 186 | + and len(call.args) >= 2 |
| 187 | + and isinstance(call.args[1], ast.Name) |
| 188 | + and call.args[1].id == "_run_tool" |
| 189 | + ): |
| 190 | + tool_submits.append(("fixed", call)) |
| 191 | + |
| 192 | + assert tool_submits, ( |
| 193 | + "Could not locate `executor.submit(... _run_tool ...)` in " |
| 194 | + "run_agent.py. The call site may have been renamed — update this " |
| 195 | + "guard along with the refactor." |
| 196 | + ) |
| 197 | + unfixed = [c for kind, c in tool_submits if kind == "unfixed"] |
| 198 | + assert not unfixed, ( |
| 199 | + "run_agent.py contains `executor.submit(_run_tool, ...)` without a " |
| 200 | + "`ctx.run` wrapper. This is the pre-#16660 shape: worker threads " |
| 201 | + "will read a fresh ContextVar and approval-session routing " |
| 202 | + "collapses to the os.environ fallback. Wrap with " |
| 203 | + "`ctx = contextvars.copy_context(); executor.submit(ctx.run, " |
| 204 | + "_run_tool, ...)`." |
| 205 | + ) |
| 206 | + |
| 207 | + |
| 208 | +def test_two_concurrent_tool_batches_keep_session_keys_isolated(): |
| 209 | + """End-to-end guard: two callers each set a different session key |
| 210 | + and submit workers concurrently. Each worker must see its own |
| 211 | + caller's key, not the other's. |
| 212 | +
|
| 213 | + Guards against a future "optimization" that reuses a single context |
| 214 | + snapshot across callers (which would collapse isolation the same way |
| 215 | + the unfixed ``submit`` does). |
| 216 | + """ |
| 217 | + from tools.approval import ( |
| 218 | + _approval_session_key, |
| 219 | + get_current_session_key, |
| 220 | + ) |
| 221 | + |
| 222 | + results: dict = {} |
| 223 | + |
| 224 | + def caller(label: str) -> None: |
| 225 | + token = _approval_session_key.set(f"session-{label}") |
| 226 | + try: |
| 227 | + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex: |
| 228 | + ctx = contextvars.copy_context() |
| 229 | + fut = ex.submit( |
| 230 | + ctx.run, |
| 231 | + lambda: get_current_session_key(default="FALLBACK"), |
| 232 | + ) |
| 233 | + results[label] = fut.result(timeout=5) |
| 234 | + finally: |
| 235 | + _approval_session_key.reset(token) |
| 236 | + |
| 237 | + t_a = threading.Thread(target=caller, args=("A",)) |
| 238 | + t_b = threading.Thread(target=caller, args=("B",)) |
| 239 | + t_a.start() |
| 240 | + t_b.start() |
| 241 | + t_a.join(timeout=10) |
| 242 | + t_b.join(timeout=10) |
| 243 | + |
| 244 | + assert results.get("A") == "session-A", ( |
| 245 | + f"Session A worker saw {results.get('A')!r}, expected 'session-A'" |
| 246 | + ) |
| 247 | + assert results.get("B") == "session-B", ( |
| 248 | + f"Session B worker saw {results.get('B')!r}, expected 'session-B'" |
| 249 | + ) |
0 commit comments