From b9e7d95c81f963e291ce61140eb9c61d46fd2776 Mon Sep 17 00:00:00 2001 From: JermaineHua <crazyhzm@apache.org> Date: Sun, 20 Apr 2025 23:48:37 +0800 Subject: [PATCH 1/3] Fix transport session id mismatch with sessionId Signed-off-by: JermaineHua <crazyhzm@apache.org> --- .../server/transport/WebFluxSseServerTransportProvider.java | 5 +++-- .../server/transport/WebMvcSseServerTransportProvider.java | 2 +- .../java/io/modelcontextprotocol/server/McpAsyncServer.java | 6 +++--- .../transport/HttpServletSseServerTransportProvider.java | 2 +- .../server/transport/StdioServerTransportProvider.java | 4 +++- .../java/io/modelcontextprotocol/spec/McpServerSession.java | 3 ++- .../MockMcpServerTransportProvider.java | 4 +++- .../server/transport/StdioServerTransportProviderTests.java | 4 ++-- 8 files changed, 18 insertions(+), 12 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..c5cf2a8d 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -2,6 +2,7 @@ import java.io.IOException; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; @@ -261,8 +262,8 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) { .body(Flux.<ServerSentEvent<?>>create(sink -> { WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); - McpServerSession session = sessionFactory.create(sessionTransport); - String sessionId = session.getId(); + String sessionId = UUID.randomUUID().toString(); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); logger.debug("Created new SSE connection for session: {}", sessionId); sessions.put(sessionId, session); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..79973a98 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -263,7 +263,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { }); WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); this.sessions.put(sessionId, session); try { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 906cb9a0..b89e844c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -340,9 +340,9 @@ private static class AsyncServerImpl extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider + .setSessionFactory((id, transport) -> new McpServerSession(id, requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff47..40215089 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -219,7 +219,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) writer); // Create a new session using the session factory - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); this.sessions.put(sessionId, session); // Send initial endpoint event diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da977..297c1b2f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -12,6 +12,7 @@ import java.io.Reader; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.UUID; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -94,7 +95,8 @@ public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream input public void setSessionFactory(McpServerSession.Factory sessionFactory) { // Create a single session for the stdio connection var transport = new StdioMcpSessionTransport(); - this.session = sessionFactory.create(transport); + String sessionId = UUID.randomUUID().toString(); + this.session = sessionFactory.create(sessionId, transport); transport.initProcessing(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 64315095..47607165 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -342,10 +342,11 @@ public interface Factory { /** * Creates a new 1:1 representation of the client-server interaction. + * @param sessionId the id of the session. * @param sessionTransport the transport to use for communication with the client. * @return a new server session. */ - McpServerSession create(McpServerTransport sessionTransport); + McpServerSession create(String sessionId, McpServerTransport sessionTransport); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf..d60d7120 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -16,6 +16,7 @@ package io.modelcontextprotocol; import java.util.Map; +import java.util.UUID; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -43,7 +44,8 @@ public MockMcpServerTransport getTransport() { @Override public void setSessionFactory(Factory sessionFactory) { - session = sessionFactory.create(transport); + String sessionId = UUID.randomUUID().toString(); + session = sessionFactory.create(sessionId, transport); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5a..f447a1bd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -71,7 +71,7 @@ void setUp() { sessionFactory = mock(McpServerSession.Factory.class); // Configure mock behavior - when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(sessionFactory.create(any(), any(McpServerTransport.class))).thenReturn(mockSession); when(mockSession.closeGracefully()).thenReturn(Mono.empty()); when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); @@ -110,7 +110,7 @@ void shouldHandleIncomingMessages() throws Exception { AtomicReference<McpSchema.JSONRPCMessage> capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); - McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession.Factory realSessionFactory = (id, transport) -> { McpServerSession session = mock(McpServerSession.class); when(session.handle(any())).thenAnswer(invocation -> { capturedMessage.set(invocation.getArgument(0)); From 499d114eea21f5351a04432a47b5dcb07dc5ec2e Mon Sep 17 00:00:00 2001 From: JermaineHua <crazyhzm@apache.org> Date: Fri, 9 May 2025 22:21:31 +0800 Subject: [PATCH 2/3] Support generateId method for McpServerSession.Factory Signed-off-by: JermaineHua <crazyhzm@apache.org> --- .../transport/WebFluxSseServerTransportProvider.java | 2 +- .../transport/WebMvcSseServerTransportProvider.java | 2 +- .../transport/HttpServletSseServerTransportProvider.java | 2 +- .../server/transport/StdioServerTransportProvider.java | 2 +- .../io/modelcontextprotocol/spec/McpServerSession.java | 9 +++++++++ .../MockMcpServerTransportProvider.java | 2 +- .../transport/StdioServerTransportProviderTests.java | 1 + 7 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index c5cf2a8d..93cb10d0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -262,7 +262,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) { .body(Flux.<ServerSentEvent<?>>create(sink -> { WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); - String sessionId = UUID.randomUUID().toString(); + String sessionId = sessionFactory.generateId(); McpServerSession session = sessionFactory.create(sessionId, sessionTransport); logger.debug("Created new SSE connection for session: {}", sessionId); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 79973a98..879129f3 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -247,7 +247,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - String sessionId = UUID.randomUUID().toString(); + String sessionId = sessionFactory.generateId(); logger.debug("Creating new SSE connection for session: {}", sessionId); // Send initial endpoint event diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 40215089..f8ea97a4 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -208,7 +208,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) response.setHeader("Connection", "keep-alive"); response.setHeader("Access-Control-Allow-Origin", "*"); - String sessionId = UUID.randomUUID().toString(); + String sessionId = sessionFactory.generateId(); AsyncContext asyncContext = request.startAsync(); asyncContext.setTimeout(0); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 297c1b2f..4c5015fb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -95,7 +95,7 @@ public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream input public void setSessionFactory(McpServerSession.Factory sessionFactory) { // Create a single session for the stdio connection var transport = new StdioMcpSessionTransport(); - String sessionId = UUID.randomUUID().toString(); + String sessionId = sessionFactory.generateId(); this.session = sessionFactory.create(sessionId, transport); transport.initProcessing(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 47607165..c1056634 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -2,6 +2,7 @@ import java.time.Duration; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -348,6 +349,14 @@ public interface Factory { */ McpServerSession create(String sessionId, McpServerTransport sessionTransport); + /** + * Generates a unique session id. + * @return a unique session id. + */ + default String generateId() { + return UUID.randomUUID().toString(); + } + } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index d60d7120..5369118a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -44,7 +44,7 @@ public MockMcpServerTransport getTransport() { @Override public void setSessionFactory(Factory sessionFactory) { - String sessionId = UUID.randomUUID().toString(); + String sessionId = sessionFactory.generateId(); session = sessionFactory.create(sessionId, transport); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index f447a1bd..5d33fed7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -10,6 +10,7 @@ import java.io.PrintStream; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; From f67abae86434bf2a571c43d355acada9485619bc Mon Sep 17 00:00:00 2001 From: JermaineHua <crazyhzm@apache.org> Date: Fri, 9 May 2025 22:26:05 +0800 Subject: [PATCH 3/3] Format the code style Signed-off-by: JermaineHua <crazyhzm@apache.org> --- .../java/io/modelcontextprotocol/server/McpAsyncServer.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index bc2c8937..3f9bef29 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -183,9 +183,8 @@ public class McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - (id, transport) -> new McpServerSession(id, requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory((id, transport) -> new McpServerSession(id, requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } // ---------------------------------------