diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 0895e02b..4011ebc0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -8,6 +8,7 @@ import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import com.fasterxml.jackson.core.type.TypeReference; @@ -236,7 +237,18 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc this.pendingResponses.remove(requestId); sink.error(error); }); - }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { + }).timeout(this.requestTimeout).onErrorResume(e -> { + if (e instanceof TimeoutException) { + return Mono.fromRunnable(() -> { + this.pendingResponses.remove(requestId); + McpSchema.CancellationMessageNotification cancellationMessageNotification = new McpSchema.CancellationMessageNotification( + requestId, "The request times out, timeout: " + requestTimeout.toMillis() + " ms"); + sendNotification(McpSchema.METHOD_NOTIFICATION_CANCELLED, cancellationMessageNotification) + .subscribe(); + }).then(Mono.error(e)); + } + return Mono.error(e); + }).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { sink.error(new McpError(jsonRpcResponse.error())); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e621ac19..c27ab809 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -29,6 +29,7 @@ * Context Protocol Schema. * * @author Christian Tzolov + * @author Jermaine Hua */ public final class McpSchema { @@ -50,6 +51,8 @@ private McpSchema() { public static final String METHOD_NOTIFICATION_INITIALIZED = "notifications/initialized"; + public static final String METHOD_NOTIFICATION_CANCELLED = "notifications/cancelled"; + public static final String METHOD_PING = "ping"; // Tool Methods @@ -211,6 +214,16 @@ public record JSONRPCError( } }// @formatter:on + // --------------------------- + // Cancellation Message Notification + // --------------------------- + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CancellationMessageNotification( // @formatter:off + @JsonProperty("requestId") String requestId, + @JsonProperty("reason") String reason){ + } // @formatter:on + // --------------------------- // Initialization // --------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 46014af8..6b22ae71 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -53,6 +53,11 @@ public class McpServerSession implements McpSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + /** + * keyed by request ID, value is true if the request is being cancelled. + */ + private final Map requestCancellation = new ConcurrentHashMap<>(); + /** * Creates a new server session with the given parameters and the transport to use. * @param id session id @@ -165,13 +170,18 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); + requestCancellation.put(request.id(), false); return handleIncomingRequest(request).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); // TODO: Should the error go to SSE or back as POST return? - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage); + return this.transport.sendMessage(errorResponse) + .doFinally(signal -> requestCancellation.remove(request.id())) + .then(Mono.empty()); + }) + .flatMap(response -> this.transport.sendMessage(response) + .doFinally(signal -> requestCancellation.remove(request.id()))); } else if (message instanceof McpSchema.JSONRPCNotification notification) { // TODO handle errors for communication to without initialization @@ -207,6 +217,11 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR resultMono = this.initRequestHandler.handle(initializeRequest); } else { + // cancellation request + if (requestCancellation.get(request.id())) { + requestCancellation.remove(request.id()); + return Mono.empty(); + } // TODO handle errors for communication to this session without // initialization happening first var handler = this.requestHandlers.get(request.method()); @@ -217,14 +232,32 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + resultMono = this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(exchange, request.params()).flatMap(result -> { + if (requestCancellation.get(request.id())) { + requestCancellation.remove(request.id()); + return Mono.empty(); + } + else { + return Mono.just(result); + } + }).doOnCancel(() -> requestCancellation.remove(request.id()))); + } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .onErrorResume(error -> { + if (requestCancellation.get(request.id())) { + requestCancellation.remove(request.id()); + return Mono.empty(); + } + else { + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null))); + } + }); // TODO: add error message + // through the data field }); } @@ -240,6 +273,17 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); return this.initNotificationHandler.handle(); } + else if (McpSchema.METHOD_NOTIFICATION_CANCELLED.equals(notification.method())) { + McpSchema.CancellationMessageNotification cancellationMessageNotification = transport + .unmarshalFrom(notification.params(), new TypeReference<>() { + }); + if (requestCancellation.containsKey(cancellationMessageNotification.requestId())) { + logger.warn("Received cancellation notification for request {}, cancellation reason is {}", + cancellationMessageNotification.requestId(), cancellationMessageNotification.reason()); + requestCancellation.put(cancellationMessageNotification.requestId(), true); + } + return Mono.empty(); + } var handler = notificationHandlers.get(notification.method()); if (handler == null) { diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index f72be43e..4900f806 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -18,6 +18,7 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.spec.McpSchema.METHOD_NOTIFICATION_CANCELLED; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -119,6 +120,25 @@ void testRequestTimeout() { .verify(TIMEOUT.plusSeconds(1)); } + @Test + void testCancellationMessageNotificationForRequestTimeout() { + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + StepVerifier.create(responseMono) + .expectError(java.util.concurrent.TimeoutException.class) + .verify(TIMEOUT.plusSeconds(1)); + + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); + McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; + assertThat(notification.method()).isEqualTo(METHOD_NOTIFICATION_CANCELLED); + McpSchema.CancellationMessageNotification cancellationMessageNotification = transport + .unmarshalFrom(notification.params(), new TypeReference<>() { + }); + assertThat(cancellationMessageNotification.reason() + .contains("The request times out, timeout: " + TIMEOUT.toMillis() + " ms")).isTrue(); + } + @Test void testSendNotification() { Map params = Map.of("key", "value");