Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 81 additions & 20 deletions tools/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,14 +906,43 @@ async def _wait_for_lifecycle_event(self) -> str:
with a fresh signal.

Shutdown takes precedence if both events are set simultaneously.

Periodically sends a lightweight keepalive (``list_tools``) to
prevent TCP connections from going stale during long idle
periods (#17003). If the keepalive fails, triggers a reconnect.
"""
# Keepalive interval in seconds. Must be shorter than typical
# LB / NAT idle-timeout (commonly 300-600s).
_KEEPALIVE_INTERVAL = 180 # 3 minutes

shutdown_task = asyncio.create_task(self._shutdown_event.wait())
reconnect_task = asyncio.create_task(self._reconnect_event.wait())
try:
await asyncio.wait(
{shutdown_task, reconnect_task},
return_when=asyncio.FIRST_COMPLETED,
)
while True:
done, _pending = await asyncio.wait(
{shutdown_task, reconnect_task},
timeout=_KEEPALIVE_INTERVAL,
return_when=asyncio.FIRST_COMPLETED,
)
if done:
break

# Timeout — no lifecycle event fired. Send a keepalive
# to exercise the connection and detect stale sockets.
if self.session:
try:
await asyncio.wait_for(
self.session.list_tools(),
timeout=30.0,
)
except Exception as exc:
logger.warning(
"MCP server '%s' keepalive failed, "
"triggering reconnect: %s",
self.name, exc,
)
self._reconnect_event.set()
return "reconnect"
finally:
for t in (shutdown_task, reconnect_task):
if not t.done():
Expand Down Expand Up @@ -1253,6 +1282,25 @@ async def shutdown(self):
_server_error_counts: Dict[str, int] = {}
_CIRCUIT_BREAKER_THRESHOLD = 3

# Half-open recovery: after the breaker trips, wait this many seconds
# before allowing a single probe call through. If the probe succeeds
# the breaker resets; if it fails the breaker re-opens. (#16788)
_CIRCUIT_BREAKER_COOLDOWN_SEC = 60
_server_breaker_opened_at: Dict[str, float] = {}


def _bump_server_error(server_name: str) -> None:
"""Increment error count and record breaker-open timestamp if threshold reached."""
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
if _server_error_counts[server_name] >= _CIRCUIT_BREAKER_THRESHOLD:
_server_breaker_opened_at[server_name] = time.time()


def _reset_server_error(server_name: str) -> None:
"""Reset error count and clear breaker-open timestamp on success."""
_server_error_counts[server_name] = 0
_server_breaker_opened_at.pop(server_name, None)

# ---------------------------------------------------------------------------
# Auth-failure detection helpers (Task 6 of MCP OAuth consolidation)
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1396,10 +1444,10 @@ async def _recover():
try:
parsed = json.loads(result)
if "error" not in parsed:
_server_error_counts[server_name] = 0
_reset_server_error(server_name)
return result
except (json.JSONDecodeError, TypeError):
_server_error_counts[server_name] = 0
_reset_server_error(server_name)
return result
except Exception as retry_exc:
logger.warning(
Expand All @@ -1410,7 +1458,7 @@ async def _recover():
# No recovery available, or retry also failed: surface a structured
# needs_reauth error. Bumps the circuit breaker so the model stops
# retrying the tool.
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
_bump_server_error(server_name)
return json.dumps({
"error": (
f"MCP server '{server_name}' requires re-authentication. "
Expand Down Expand Up @@ -1615,20 +1663,33 @@ def _handler(args: dict, **kwargs) -> str:
# Circuit breaker: if this server has failed too many times
# consecutively, short-circuit with a clear message so the model
# stops retrying and uses alternative approaches (#10447).
# After _CIRCUIT_BREAKER_COOLDOWN_SEC, allow a single probe call
# through (half-open state) to recover from transient subprocess
# deaths (#16788).
if _server_error_counts.get(server_name, 0) >= _CIRCUIT_BREAKER_THRESHOLD:
return json.dumps({
"error": (
f"MCP server '{server_name}' is unreachable after "
f"{_CIRCUIT_BREAKER_THRESHOLD} consecutive failures. "
f"Do NOT retry this tool — use alternative approaches "
f"or ask the user to check the MCP server."
)
}, ensure_ascii=False)
opened = _server_breaker_opened_at.get(server_name, 0)
if opened and (time.time() - opened) < _CIRCUIT_BREAKER_COOLDOWN_SEC:
# Breaker still in open state — block the call
return json.dumps({
"error": (
f"MCP server '{server_name}' is unreachable after "
f"{_CIRCUIT_BREAKER_THRESHOLD} consecutive failures. "
f"Retry in {_CIRCUIT_BREAKER_COOLDOWN_SEC}s — "
f"do NOT call this tool again until then."
)
}, ensure_ascii=False)
# Cooldown elapsed — enter half-open state: allow one probe.
# If the probe succeeds, _server_error_counts is reset below.
# If it fails, the count is bumped and a new cooldown starts.
logger.info(
"MCP server '%s': circuit breaker half-open, allowing probe call",
server_name,
)

with _lock:
server = _servers.get(server_name)
if not server or not server.session:
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
_bump_server_error(server_name)
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
}, ensure_ascii=False)
Expand Down Expand Up @@ -1677,11 +1738,11 @@ def _call_once():
try:
parsed = json.loads(result)
if "error" in parsed:
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
_bump_server_error(server_name)
else:
_server_error_counts[server_name] = 0 # success — reset
_reset_server_error(server_name) # success — reset
except (json.JSONDecodeError, TypeError):
_server_error_counts[server_name] = 0 # non-JSON = success
_reset_server_error(server_name) # non-JSON = success
return result
except InterruptedError:
return _interrupted_call_result()
Expand All @@ -1696,7 +1757,7 @@ def _call_once():
if recovered is not None:
return recovered

_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
_bump_server_error(server_name)
logger.error(
"MCP tool %s/%s call failed: %s",
server_name, tool_name, exc,
Expand Down