Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 199 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import asyncio
from urllib.parse import urlparse, parse_qs
import socket
import warnings
import brotli
Expand Down Expand Up @@ -51,24 +53,33 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional
from typing import Any, Callable, MutableMapping, Optional
from collections.abc import Iterator

try:
from anyio import create_memory_object_stream, create_task_group
from anyio import create_memory_object_stream, create_task_group, EndOfStream
from mcp.types import (
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
)
from mcp.shared.message import SessionMessage
from httpx import ASGITransport, Request, Response, AsyncByteStream, AsyncClient
except ImportError:
create_memory_object_stream = None
create_task_group = None
EndOfStream = None

JSONRPCMessage = None
JSONRPCRequest = None
SessionMessage = None

ASGITransport = None
Request = None
Response = None
AsyncByteStream = None
AsyncClient = None


SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json"

Expand Down Expand Up @@ -786,6 +797,192 @@ def inner(events):
return inner


@pytest.fixture()
def json_rpc_sse(is_structured_content: bool = True):
class StreamingASGITransport(ASGITransport):
"""
Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing
tests involving SSE interactions to run in-process.
"""

def __init__(
self,
app: "Callable",
keep_sse_alive: "asyncio.Event",
) -> None:
self.keep_sse_alive = keep_sse_alive
super().__init__(app)

async def handle_async_request(self, request: "Request") -> "Response":
scope = {
"type": "http",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"path": request.url.path,
"query_string": request.url.query,
}

is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
if not is_streaming_sse:
return await super().handle_async_request(request)

request_body = b""
if request.content:
request_body = await request.aread()

body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore

async def receive() -> "dict[str, Any]":
if self.keep_sse_alive.is_set():
return {"type": "http.disconnect"}

await self.keep_sse_alive.wait() # Keep alive :)
return {
"type": "http.request",
"body": request_body,
"more_body": False,
}

async def send(message: "MutableMapping[str, Any]") -> None:
if message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body == b"" and not more_body:
return

if body:
await body_sender.send(body)

if not more_body:
await body_sender.aclose()

async def run_app():
await self.app(scope, receive, send)

class StreamingBodyStream(AsyncByteStream): # type: ignore
def __init__(self, receiver, task):
self.receiver = receiver
self.task = task

async def __aiter__(self):
try:
async for chunk in self.receiver:
yield chunk
except EndOfStream: # type: ignore
pass

stream = StreamingBodyStream(body_receiver, asyncio.create_task(run_app()))
response = Response(status_code=200, headers=[], stream=stream) # type: ignore

return response

def parse_sse_data_package(sse_chunk):
sse_text = sse_chunk.decode("utf-8")
json_str = sse_text.split("data: ")[1]
return json.loads(json_str)

async def inner(
app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
):
context = {}

stream_complete = asyncio.Event()
endpoint_parsed = asyncio.Event()

# https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
async with AsyncClient( # type: ignore
transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
base_url="http://test",
) as client:

async def parse_stream():
async with client.stream("GET", "/sse") as stream:
# Read directly from stream.stream instead of aiter_bytes()
async for chunk in stream.stream:
if b"event: endpoint" in chunk:
sse_text = chunk.decode("utf-8")
url = sse_text.split("data: ")[1]

parsed = urlparse(url)
query_params = parse_qs(parsed.query)
context["session_id"] = query_params["session_id"][0]
endpoint_parsed.set()
continue

if (
is_structured_content
and b"event: message" in chunk
and b"structuredContent" in chunk
):
context["response"] = parse_sse_data_package(chunk)
stream_complete.set()
break
elif (
"result" in parse_sse_data_package(chunk)
and "content" in parse_sse_data_package(chunk)["result"]
):
context["response"] = parse_sse_data_package(chunk)
stream_complete.set()
break

task = asyncio.create_task(parse_stream())
await endpoint_parsed.wait()

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
},
json={
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": "2025-11-25",
"capabilities": {},
},
"id": request_id,
},
)

# Notification response is mandatory.
# https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
},
)

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": request_id,
},
)

await stream_complete.wait()
keep_sse_alive.set()

return task, context["session_id"], context["response"]

return inner


class MockServerRequestHandler(BaseHTTPRequestHandler):
def do_GET(self): # noqa: N802
# Process an HTTP GET request and return a response.
Expand Down
115 changes: 75 additions & 40 deletions tests/integrations/fastmcp/test_fastmcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
accurate testing of the integration's behavior in real MCP Server scenarios.
"""

import anyio
import asyncio
import json
import pytest
Expand All @@ -39,9 +40,12 @@ async def __call__(self, *args, **kwargs):
from sentry_sdk.consts import SPANDATA, OP
from sentry_sdk.integrations.mcp import MCPIntegration

from mcp.server.lowlevel import Server
from mcp.server.sse import SseServerTransport
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

from starlette.routing import Mount
from starlette.routing import Mount, Route
from starlette.responses import Response
from starlette.applications import Starlette

# Try to import both FastMCP implementations
Expand Down Expand Up @@ -260,34 +264,6 @@ def reset_request_ctx():
pass


class MockRequestContext:
"""Mock MCP request context"""

def __init__(self, request_id=None, session_id=None, transport="stdio"):
self.request_id = request_id
if transport in ("http", "sse"):
self.request = MockHTTPRequest(session_id, transport)
else:
self.request = None


class MockHTTPRequest:
"""Mock HTTP request for SSE/StreamableHTTP transport"""

def __init__(self, session_id=None, transport="http"):
self.headers = {}
self.query_params = {}

if transport == "sse":
# SSE transport uses query parameter
if session_id:
self.query_params["session_id"] = session_id
else:
# StreamableHTTP transport uses header
if session_id:
self.headers["mcp-session-id"] = session_id


# =============================================================================
# Tool Handler Tests - Verifying Sentry Integration
# =============================================================================
Expand Down Expand Up @@ -1029,8 +1005,11 @@ def test_tool_no_ctx(x: int) -> dict:
# =============================================================================


@pytest.mark.asyncio
@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
def test_fastmcp_sse_transport(sentry_init, capture_events, FastMCP):
async def test_fastmcp_sse_transport(
sentry_init, capture_events, FastMCP, json_rpc_sse
):
"""Test that FastMCP correctly detects SSE transport"""
sentry_init(
integrations=[MCPIntegration()],
Expand All @@ -1039,25 +1018,81 @@ def test_fastmcp_sse_transport(sentry_init, capture_events, FastMCP):
events = capture_events()

mcp = FastMCP("Test Server")
sse = SseServerTransport("/messages/")

# Set up mock request context with SSE transport
if request_ctx is not None:
mock_ctx = MockRequestContext(
request_id="req-sse", session_id="session-sse-123", transport="sse"
)
request_ctx.set(mock_ctx)
sse_connection_closed = asyncio.Event()

async def handle_sse(request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
async with anyio.create_task_group() as tg:

async def run_server():
await mcp._mcp_server.run(
streams[0],
streams[1],
mcp._mcp_server.create_initialization_options(),
)

tg.start_soon(run_server)

sse_connection_closed.set()
return Response()

app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message),
],
)

@mcp.tool()
def sse_tool(value: str) -> dict:
"""Tool for SSE transport test"""
return {"message": f"Received: {value}"}

with start_transaction(name="fastmcp tx"):
result = call_tool_through_mcp(mcp, "sse_tool", {"value": "hello"})
keep_sse_alive = asyncio.Event()
app_task, _, result = await json_rpc_sse(
app,
method="tools/call",
params={
"name": "sse_tool",
"arguments": {"value": "hello"},
},
request_id="req-sse",
keep_sse_alive=keep_sse_alive,
)

assert result == {"message": "Received: hello"}
await sse_connection_closed.wait()
await app_task

(tx,) = events
if (
isinstance(mcp, StandaloneFastMCP)
and FASTMCP_VERSION is not None
and FASTMCP_VERSION.startswith("2")
):
assert result["result"]["content"][0]["text"] == json.dumps(
{"message": "Received: hello"}, separators=(",", ":")
)
elif (
isinstance(mcp, StandaloneFastMCP) and FASTMCP_VERSION is not None
): # Checking for None is not precise.
assert result["result"]["content"][0]["text"] == json.dumps(
{"message": "Received: hello"}
)
else:
assert result["result"]["content"][0]["text"] == json.dumps(
{"message": "Received: hello"}, indent=2
)

transactions = [
event
for event in events
if event["type"] == "transaction" and event["transaction"] == "/sse"
]
assert len(transactions) == 1
tx = transactions[0]

# Find MCP spans
mcp_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]
Expand Down
Loading
Loading