209 changes: 207 additions & 2 deletions tests/conftest.py
@@ -1,5 +1,7 @@
import json
import os
import asyncio
from urllib.parse import urlparse, parse_qs
import socket
import warnings
import brotli
@@ -51,25 +53,40 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional
from typing import Any, Callable, MutableMapping, Optional
from collections.abc import Iterator

try:
from anyio import create_memory_object_stream, create_task_group
from anyio import create_memory_object_stream, create_task_group, EndOfStream
from mcp.types import (
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
)
from mcp.shared.message import SessionMessage
from httpx import (
ASGITransport,
Request as HttpxRequest,
Response as HttpxResponse,
AsyncByteStream,
AsyncClient,
)
except ImportError:
create_memory_object_stream = None
create_task_group = None
EndOfStream = None

JSONRPCMessage = None
JSONRPCNotification = None
JSONRPCRequest = None
SessionMessage = None

ASGITransport = None
HttpxRequest = None
HttpxResponse = None
AsyncByteStream = None
AsyncClient = None


SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json"

@@ -787,6 +804,194 @@ def inner(events):
return inner


@pytest.fixture()
def json_rpc_sse(is_structured_content: bool = True):
class StreamingASGITransport(ASGITransport):
"""
Simple transport whose only purpose is to keep the GET request alive for SSE
connections, allowing tests involving SSE interactions to run in-process.
"""

def __init__(
self,
app: "Callable",
keep_sse_alive: "asyncio.Event",
) -> None:
self.keep_sse_alive = keep_sse_alive
super().__init__(app)

async def handle_async_request(
self, request: "HttpxRequest"
) -> "HttpxResponse":
scope = {
"type": "http",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"path": request.url.path,
"query_string": request.url.query,
}

is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
if not is_streaming_sse:
return await super().handle_async_request(request)

request_body = b""
if request.content:
request_body = await request.aread()

body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore
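# ASGI ``receive`` callable: while ``keep_sse_alive`` is unset it blocks, which
# keeps the SSE response open; once the event is set, the blocked call delivers
# the request body and any later call reports a client disconnect.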

async def receive() -> "dict[str, Any]":
if self.keep_sse_alive.is_set():
return {"type": "http.disconnect"}

await self.keep_sse_alive.wait() # Keep alive :)
return {
"type": "http.request",
"body": request_body,
"more_body": False,
}
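
# ASGI ``send`` callable: response body chunks are forwarded into the in-memory
# stream for the test client to consume; the response start message
# (status/headers) is not needed by these tests and is ignored.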

async def send(message: "MutableMapping[str, Any]") -> None:
if message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body == b"" and not more_body:
return

if body:
await body_sender.send(body)

if not more_body:
await body_sender.aclose()

async def run_app():
await self.app(scope, receive, send)
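
# Adapt the receiving end of the memory stream to httpx's AsyncByteStream
# interface so it can back the streamed Response returned below.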

class StreamingBodyStream(AsyncByteStream): # type: ignore
def __init__(self, receiver):
self.receiver = receiver

async def __aiter__(self):
try:
async for chunk in self.receiver:
yield chunk
except EndOfStream: # type: ignore
pass

stream = StreamingBodyStream(body_receiver)
response = HttpxResponse(status_code=200, headers=[], stream=stream) # type: ignore

asyncio.create_task(run_app())
return response
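
# Each SSE frame is text of the form "event: <name>\ndata: <payload>\n\n";
# the helper below extracts and decodes the JSON payload after "data: ".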

def parse_sse_data_package(sse_chunk):
sse_text = sse_chunk.decode("utf-8")
json_str = sse_text.split("data: ")[1]
return json.loads(json_str)

async def inner(
app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
):
context = {}

stream_complete = asyncio.Event()
endpoint_parsed = asyncio.Event()

# https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
async with AsyncClient( # type: ignore
transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
base_url="http://test",
) as client:
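
# The server's first SSE frame is an "endpoint" event whose data is the URL
# (carrying the session_id) to POST JSON-RPC messages to; subsequent
# "message" events carry the JSON-RPC responses.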

async def parse_stream():
async with client.stream("GET", "/sse") as stream:
# Read directly from stream.stream instead of aiter_bytes()
async for chunk in stream.stream:
if b"event: endpoint" in chunk:
sse_text = chunk.decode("utf-8")
url = sse_text.split("data: ")[1]

parsed = urlparse(url)
query_params = parse_qs(parsed.query)
context["session_id"] = query_params["session_id"][0]
endpoint_parsed.set()
continue

if (
is_structured_content
and b"event: message" in chunk
and b"structuredContent" in chunk
):
context["response"] = parse_sse_data_package(chunk)
break
elif (
"result" in parse_sse_data_package(chunk)
and "content" in parse_sse_data_package(chunk)["result"]
):
context["response"] = parse_sse_data_package(chunk)
break

stream_complete.set()

task = asyncio.create_task(parse_stream())
await endpoint_parsed.wait()
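
# MCP handshake over the SSE transport: POST initialize, then the initialized
# notification, then the actual request; responses arrive on the GET /sse
# stream opened above.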

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
},
json={
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": "2025-11-25",
"capabilities": {},
},
"id": request_id,
},
)

# The initialized notification is mandatory after the initialize response.
# https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
},
)

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": request_id,
},
)

await stream_complete.wait()
keep_sse_alive.set()

return task, context["session_id"], context["response"]

return inner
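
# Illustrative usage sketch (not part of this diff). Assumes an ASGI app fixture
# and an async test runner such as pytest-asyncio; the fixture and tool names
# below are hypothetical:
#
#   async def test_tool_call_over_sse(json_rpc_sse, mcp_sse_app):
#       keep_alive = asyncio.Event()
#       task, session_id, response = await json_rpc_sse(
#           mcp_sse_app, "tools/call", {"name": "echo", "arguments": {}}, "1", keep_alive
#       )
#       assert "result" in response
#       await task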


class MockServerRequestHandler(BaseHTTPRequestHandler):
def do_GET(self): # noqa: N802
# Process an HTTP GET request and return a response.