|
14 | 14 | from mcp.server.fastmcp import Context, FastMCP |
15 | 15 | from mcp.server.fastmcp.prompts.base import UserMessage |
16 | 16 | from mcp.server.session import ServerSession |
| 17 | +from mcp.server.streamable_http import EventCallback, EventMessage, EventStore |
17 | 18 | from mcp.types import ( |
18 | 19 | AudioContent, |
19 | 20 | Completion, |
20 | 21 | CompletionArgument, |
21 | 22 | CompletionContext, |
22 | 23 | EmbeddedResource, |
23 | 24 | ImageContent, |
| 25 | + JSONRPCMessage, |
24 | 26 | PromptReference, |
25 | 27 | ResourceTemplateReference, |
26 | 28 | SamplingMessage, |
|
31 | 33 |
|
32 | 34 | logger = logging.getLogger(__name__) |
33 | 35 |
|
| 36 | +# Type aliases for event store |
| 37 | +StreamId = str |
| 38 | +EventId = str |
| 39 | + |
| 40 | + |
| 41 | +class InMemoryEventStore(EventStore): |
| 42 | + """Simple in-memory event store for SSE resumability testing.""" |
| 43 | + |
| 44 | + def __init__(self) -> None: |
| 45 | + self._events: list[tuple[StreamId, EventId, JSONRPCMessage | None]] = [] |
| 46 | + self._event_id_counter = 0 |
| 47 | + |
| 48 | + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: |
| 49 | + """Store an event and return its ID.""" |
| 50 | + self._event_id_counter += 1 |
| 51 | + event_id = str(self._event_id_counter) |
| 52 | + self._events.append((stream_id, event_id, message)) |
| 53 | + return event_id |
| 54 | + |
| 55 | + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: |
| 56 | + """Replay events after the specified ID.""" |
| 57 | + target_stream_id = None |
| 58 | + for stream_id, event_id, _ in self._events: |
| 59 | + if event_id == last_event_id: |
| 60 | + target_stream_id = stream_id |
| 61 | + break |
| 62 | + if target_stream_id is None: |
| 63 | + return None |
| 64 | + last_event_id_int = int(last_event_id) |
| 65 | + for stream_id, event_id, message in self._events: |
| 66 | + if stream_id == target_stream_id and int(event_id) > last_event_id_int: |
| 67 | + # Skip priming events (None message) |
| 68 | + if message is not None: |
| 69 | + await send_callback(EventMessage(message, event_id)) |
| 70 | + return target_stream_id |
| 71 | + |
| 72 | + |
34 | 73 | # Test data |
35 | 74 | TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" |
36 | 75 | TEST_AUDIO_BASE64 = "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAAB9AAACABAAZGF0YQIAAAA=" |
|
39 | 78 | resource_subscriptions: set[str] = set() |
40 | 79 | watched_resource_content = "Watched resource content" |
41 | 80 |
|
| 81 | +# Create event store for SSE resumability (SEP-1699) |
| 82 | +event_store = InMemoryEventStore() |
| 83 | + |
42 | 84 | mcp = FastMCP( |
43 | 85 | name="mcp-conformance-test-server", |
| 86 | + event_store=event_store, |
| 87 | + retry_interval=100, # 100ms retry interval for SSE polling |
44 | 88 | ) |
45 | 89 |
|
46 | 90 |
|
@@ -263,6 +307,19 @@ def test_error_handling() -> str: |
263 | 307 | raise RuntimeError("This tool intentionally returns an error for testing") |
264 | 308 |
|
265 | 309 |
|
| 310 | +@mcp.tool() |
| 311 | +async def test_reconnection(ctx: Context[ServerSession, None]) -> str: |
| 312 | + """Tests SSE polling by closing stream mid-call (SEP-1699)""" |
| 313 | + await ctx.info("Before disconnect") |
| 314 | + |
| 315 | + await ctx.close_sse_stream() |
| 316 | + |
| 317 | + await asyncio.sleep(0.2) # Wait for client to reconnect |
| 318 | + |
| 319 | + await ctx.info("After reconnect") |
| 320 | + return "Reconnection test completed" |
| 321 | + |
| 322 | + |
266 | 323 | # Resources |
267 | 324 | @mcp.resource("test://static-text") |
268 | 325 | def static_text_resource() -> str: |
|
0 commit comments