|
| 1 | +"""Tests for MCP tool-handler circuit-breaker recovery. |
| 2 | +
|
| 3 | +The circuit breaker in ``tools/mcp_tool.py`` is intended to short-circuit |
| 4 | +calls to an MCP server that has failed ``_CIRCUIT_BREAKER_THRESHOLD`` |
| 5 | +consecutive times, then *transition back to a usable state* once the |
| 6 | +server has had time to recover (or an explicit reconnect succeeds). |
| 7 | +
|
| 8 | +The original implementation only had two states — closed and open — with |
| 9 | +no mechanism to transition back to closed, so a tripped breaker stayed |
| 10 | +tripped for the lifetime of the process. These tests lock in the |
| 11 | +half-open / cooldown / reconnect-resets-breaker behavior that fixes |
| 12 | +that. |
| 13 | +""" |
| 14 | +import json |
| 15 | +from unittest.mock import MagicMock |
| 16 | + |
| 17 | +import pytest |
| 18 | + |
| 19 | + |
| 20 | +pytest.importorskip("mcp.client.auth.oauth2") |
| 21 | + |
| 22 | + |
| 23 | +# --------------------------------------------------------------------------- |
| 24 | +# Helpers |
| 25 | +# --------------------------------------------------------------------------- |
| 26 | + |
| 27 | + |
| 28 | +def _install_stub_server(mcp_tool_module, name: str, call_tool_impl): |
| 29 | + """Install a fake MCP server in the module's registry. |
| 30 | +
|
| 31 | + ``call_tool_impl`` is an async function stored at ``session.call_tool`` |
| 32 | + (it's what the tool handler invokes). |
| 33 | + """ |
| 34 | + server = MagicMock() |
| 35 | + server.name = name |
| 36 | + session = MagicMock() |
| 37 | + session.call_tool = call_tool_impl |
| 38 | + server.session = session |
| 39 | + server._reconnect_event = MagicMock() |
| 40 | + server._ready = MagicMock() |
| 41 | + server._ready.is_set.return_value = True |
| 42 | + |
| 43 | + mcp_tool_module._servers[name] = server |
| 44 | + mcp_tool_module._server_error_counts.pop(name, None) |
| 45 | + if hasattr(mcp_tool_module, "_server_breaker_opened_at"): |
| 46 | + mcp_tool_module._server_breaker_opened_at.pop(name, None) |
| 47 | + return server |
| 48 | + |
| 49 | + |
| 50 | +def _cleanup(mcp_tool_module, name: str) -> None: |
| 51 | + mcp_tool_module._servers.pop(name, None) |
| 52 | + mcp_tool_module._server_error_counts.pop(name, None) |
| 53 | + if hasattr(mcp_tool_module, "_server_breaker_opened_at"): |
| 54 | + mcp_tool_module._server_breaker_opened_at.pop(name, None) |
| 55 | + |
| 56 | + |
| 57 | +# --------------------------------------------------------------------------- |
| 58 | +# Tests |
| 59 | +# --------------------------------------------------------------------------- |
| 60 | + |
| 61 | + |
| 62 | +def test_circuit_breaker_half_opens_after_cooldown(monkeypatch, tmp_path): |
| 63 | + """After a tripped breaker's cooldown elapses, the *next* call must |
| 64 | + actually execute against the session (half-open probe). When the |
| 65 | + probe succeeds, the breaker resets to fully closed. |
| 66 | + """ |
| 67 | + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) |
| 68 | + |
| 69 | + from tools import mcp_tool |
| 70 | + from tools.mcp_tool import _make_tool_handler |
| 71 | + |
| 72 | + call_count = {"n": 0} |
| 73 | + |
| 74 | + async def _call_tool_success(*a, **kw): |
| 75 | + call_count["n"] += 1 |
| 76 | + result = MagicMock() |
| 77 | + result.isError = False |
| 78 | + block = MagicMock() |
| 79 | + block.text = "ok" |
| 80 | + result.content = [block] |
| 81 | + result.structuredContent = None |
| 82 | + return result |
| 83 | + |
| 84 | + _install_stub_server(mcp_tool, "srv", _call_tool_success) |
| 85 | + mcp_tool._ensure_mcp_loop() |
| 86 | + |
| 87 | + try: |
| 88 | + # Trip the breaker by setting the count at/above threshold and |
| 89 | + # stamping the open-time to "now". |
| 90 | + mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD |
| 91 | + fake_now = [1000.0] |
| 92 | + |
| 93 | + def _fake_monotonic(): |
| 94 | + return fake_now[0] |
| 95 | + |
| 96 | + monkeypatch.setattr(mcp_tool.time, "monotonic", _fake_monotonic) |
| 97 | + # The breaker-open timestamp dict is introduced by the fix; on |
| 98 | + # a pre-fix build it won't exist, which will cause the test to |
| 99 | + # fail at the .get() inside the gate (correct — the fix is |
| 100 | + # required for this state to be tracked at all). |
| 101 | + if hasattr(mcp_tool, "_server_breaker_opened_at"): |
| 102 | + mcp_tool._server_breaker_opened_at["srv"] = fake_now[0] |
| 103 | + cooldown = getattr(mcp_tool, "_CIRCUIT_BREAKER_COOLDOWN_SEC", 60.0) |
| 104 | + |
| 105 | + handler = _make_tool_handler("srv", "tool1", 10.0) |
| 106 | + |
| 107 | + # Before cooldown: must short-circuit (no session call). |
| 108 | + result = handler({}) |
| 109 | + parsed = json.loads(result) |
| 110 | + assert "error" in parsed, parsed |
| 111 | + assert "unreachable" in parsed["error"].lower() |
| 112 | + assert call_count["n"] == 0, ( |
| 113 | + "breaker should short-circuit before cooldown elapses" |
| 114 | + ) |
| 115 | + |
| 116 | + # Advance past cooldown → next call is a half-open probe that |
| 117 | + # actually hits the session. |
| 118 | + fake_now[0] += cooldown + 1.0 |
| 119 | + |
| 120 | + result = handler({}) |
| 121 | + parsed = json.loads(result) |
| 122 | + assert parsed.get("result") == "ok", parsed |
| 123 | + assert call_count["n"] == 1, "half-open probe should invoke session" |
| 124 | + |
| 125 | + # On probe success the breaker must close (count reset to 0). |
| 126 | + assert mcp_tool._server_error_counts.get("srv", 0) == 0 |
| 127 | + finally: |
| 128 | + _cleanup(mcp_tool, "srv") |
| 129 | + |
| 130 | + |
| 131 | +def test_circuit_breaker_reopens_on_probe_failure(monkeypatch, tmp_path): |
| 132 | + """If the half-open probe fails, the breaker must re-arm the |
| 133 | + cooldown (not let every subsequent call through). |
| 134 | + """ |
| 135 | + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) |
| 136 | + |
| 137 | + from tools import mcp_tool |
| 138 | + from tools.mcp_tool import _make_tool_handler |
| 139 | + |
| 140 | + call_count = {"n": 0} |
| 141 | + |
| 142 | + async def _call_tool_fails(*a, **kw): |
| 143 | + call_count["n"] += 1 |
| 144 | + raise RuntimeError("still broken") |
| 145 | + |
| 146 | + _install_stub_server(mcp_tool, "srv", _call_tool_fails) |
| 147 | + mcp_tool._ensure_mcp_loop() |
| 148 | + |
| 149 | + try: |
| 150 | + mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD |
| 151 | + fake_now = [1000.0] |
| 152 | + |
| 153 | + def _fake_monotonic(): |
| 154 | + return fake_now[0] |
| 155 | + |
| 156 | + monkeypatch.setattr(mcp_tool.time, "monotonic", _fake_monotonic) |
| 157 | + if hasattr(mcp_tool, "_server_breaker_opened_at"): |
| 158 | + mcp_tool._server_breaker_opened_at["srv"] = fake_now[0] |
| 159 | + cooldown = getattr(mcp_tool, "_CIRCUIT_BREAKER_COOLDOWN_SEC", 60.0) |
| 160 | + |
| 161 | + handler = _make_tool_handler("srv", "tool1", 10.0) |
| 162 | + |
| 163 | + # Advance past cooldown, run probe, expect failure. |
| 164 | + fake_now[0] += cooldown + 1.0 |
| 165 | + result = handler({}) |
| 166 | + parsed = json.loads(result) |
| 167 | + assert "error" in parsed |
| 168 | + assert call_count["n"] == 1, "probe should invoke session once" |
| 169 | + |
| 170 | + # The probe failure must have re-armed the cooldown — another |
| 171 | + # immediate call should short-circuit, not invoke session again. |
| 172 | + result = handler({}) |
| 173 | + parsed = json.loads(result) |
| 174 | + assert "unreachable" in parsed.get("error", "").lower() |
| 175 | + assert call_count["n"] == 1, ( |
| 176 | + "breaker should re-open and block further calls after probe failure" |
| 177 | + ) |
| 178 | + finally: |
| 179 | + _cleanup(mcp_tool, "srv") |
| 180 | + |
| 181 | + |
| 182 | +def test_circuit_breaker_cleared_on_reconnect(monkeypatch, tmp_path): |
| 183 | + """When the auth-recovery path successfully reconnects the server, |
| 184 | + the breaker should be cleared so subsequent calls aren't gated on a |
| 185 | + stale failure count — even if the post-reconnect retry itself fails. |
| 186 | +
|
| 187 | + This locks in the fix-#2 contract: a successful reconnect is |
| 188 | + sufficient evidence that the server is viable again. Under the old |
| 189 | + implementation, reset only happened on retry *success*, so a |
| 190 | + reconnect+retry-failure left the counter pinned above threshold |
| 191 | + forever. |
| 192 | + """ |
| 193 | + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) |
| 194 | + |
| 195 | + from tools import mcp_tool |
| 196 | + from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests |
| 197 | + from mcp.client.auth import OAuthFlowError |
| 198 | + |
| 199 | + reset_manager_for_tests() |
| 200 | + |
| 201 | + async def _call_tool_unused(*a, **kw): # pragma: no cover |
| 202 | + raise AssertionError("session.call_tool should not be reached in this test") |
| 203 | + |
| 204 | + _install_stub_server(mcp_tool, "srv", _call_tool_unused) |
| 205 | + mcp_tool._ensure_mcp_loop() |
| 206 | + |
| 207 | + # Open the breaker well above threshold, with a recent open-time so |
| 208 | + # it would short-circuit everything without a reset. |
| 209 | + mcp_tool._server_error_counts["srv"] = mcp_tool._CIRCUIT_BREAKER_THRESHOLD + 2 |
| 210 | + if hasattr(mcp_tool, "_server_breaker_opened_at"): |
| 211 | + import time as _time |
| 212 | + mcp_tool._server_breaker_opened_at["srv"] = _time.monotonic() |
| 213 | + |
| 214 | + # Force handle_401 to claim recovery succeeded. |
| 215 | + mgr = get_manager() |
| 216 | + |
| 217 | + async def _h401(name, token=None): |
| 218 | + return True |
| 219 | + |
| 220 | + monkeypatch.setattr(mgr, "handle_401", _h401) |
| 221 | + |
| 222 | + try: |
| 223 | + # Retry fails *after* the successful reconnect. Under the old |
| 224 | + # implementation this bumps an already-tripped counter even |
| 225 | + # higher. Under fix #2 the reset happens on successful |
| 226 | + # reconnect, and the post-retry bump only raises the fresh |
| 227 | + # count to 1 — still below threshold. |
| 228 | + def _retry_call(): |
| 229 | + raise OAuthFlowError("still failing post-reconnect") |
| 230 | + |
| 231 | + result = mcp_tool._handle_auth_error_and_retry( |
| 232 | + "srv", |
| 233 | + OAuthFlowError("initial"), |
| 234 | + _retry_call, |
| 235 | + "tools/call test", |
| 236 | + ) |
| 237 | + # The call as a whole still surfaces needs_reauth because the |
| 238 | + # retry itself didn't succeed, but the breaker state must |
| 239 | + # reflect the successful reconnect. |
| 240 | + assert result is not None |
| 241 | + parsed = json.loads(result) |
| 242 | + assert parsed.get("needs_reauth") is True, parsed |
| 243 | + |
| 244 | + # Post-reconnect count was reset to 0, then the failing retry |
| 245 | + # bumped it to exactly 1 — well below threshold. |
| 246 | + count = mcp_tool._server_error_counts.get("srv", 0) |
| 247 | + assert count < mcp_tool._CIRCUIT_BREAKER_THRESHOLD, ( |
| 248 | + f"successful reconnect must reset the breaker below threshold; " |
| 249 | + f"got count={count}, threshold={mcp_tool._CIRCUIT_BREAKER_THRESHOLD}" |
| 250 | + ) |
| 251 | + finally: |
| 252 | + _cleanup(mcp_tool, "srv") |
0 commit comments