Skip to content

feat: Support cancellation notification while request timeout #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;

import com.fasterxml.jackson.core.type.TypeReference;
Expand Down Expand Up @@ -236,7 +237,18 @@ public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReferenc
this.pendingResponses.remove(requestId);
sink.error(error);
});
}).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> {
}).timeout(this.requestTimeout).onErrorResume(e -> {
if (e instanceof TimeoutException) {
return Mono.fromRunnable(() -> {
this.pendingResponses.remove(requestId);
McpSchema.CancellationMessageNotification cancellationMessageNotification = new McpSchema.CancellationMessageNotification(
requestId, "The request times out, timeout: " + requestTimeout.toMillis() + " ms");
sendNotification(McpSchema.METHOD_NOTIFICATION_CANCELLED, cancellationMessageNotification)
.subscribe();
}).then(Mono.error(e));
}
return Mono.error(e);
}).handle((jsonRpcResponse, sink) -> {
if (jsonRpcResponse.error() != null) {
sink.error(new McpError(jsonRpcResponse.error()));
}
Expand Down
13 changes: 13 additions & 0 deletions mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
* Context Protocol Schema</a>.
*
* @author Christian Tzolov
* @author Jermaine Hua
*/
public final class McpSchema {

Expand All @@ -50,6 +51,8 @@ private McpSchema() {

public static final String METHOD_NOTIFICATION_INITIALIZED = "notifications/initialized";

public static final String METHOD_NOTIFICATION_CANCELLED = "notifications/cancelled";

public static final String METHOD_PING = "ping";

// Tool Methods
Expand Down Expand Up @@ -211,6 +214,16 @@ public record JSONRPCError(
}
}// @formatter:on

// ---------------------------
// Cancellation Message Notification
// ---------------------------
@JsonInclude(JsonInclude.Include.NON_ABSENT)
@JsonIgnoreProperties(ignoreUnknown = true)
public record CancellationMessageNotification( // @formatter:off
@JsonProperty("requestId") String requestId,
@JsonProperty("reason") String reason){
} // @formatter:on

// ---------------------------
// Initialization
// ---------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ public class McpServerSession implements McpSession {

private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED);

/**
* keyed by request ID, value is true if the request is being cancelled.
*/
private final Map<Object, Boolean> requestCancellation = new ConcurrentHashMap<>();

/**
* Creates a new server session with the given parameters and the transport to use.
* @param id session id
Expand Down Expand Up @@ -165,13 +170,18 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
}
else if (message instanceof McpSchema.JSONRPCRequest request) {
logger.debug("Received request: {}", request);
requestCancellation.put(request.id(), false);
return handleIncomingRequest(request).onErrorResume(error -> {
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
error.getMessage(), null));
// TODO: Should the error go to SSE or back as POST return?
return this.transport.sendMessage(errorResponse).then(Mono.empty());
}).flatMap(this.transport::sendMessage);
return this.transport.sendMessage(errorResponse)
.doFinally(signal -> requestCancellation.remove(request.id()))
.then(Mono.empty());
})
.flatMap(response -> this.transport.sendMessage(response)
.doFinally(signal -> requestCancellation.remove(request.id())));
}
else if (message instanceof McpSchema.JSONRPCNotification notification) {
// TODO handle errors for communication to without initialization
Expand Down Expand Up @@ -207,6 +217,11 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
resultMono = this.initRequestHandler.handle(initializeRequest);
}
else {
// cancellation request
if (requestCancellation.get(request.id())) {
requestCancellation.remove(request.id());
return Mono.empty();
}
// TODO handle errors for communication to this session without
// initialization happening first
var handler = this.requestHandlers.get(request.method());
Expand All @@ -217,14 +232,32 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
error.message(), error.data())));
}

resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params()));
resultMono = this.exchangeSink.asMono()
.flatMap(exchange -> handler.handle(exchange, request.params()).flatMap(result -> {
if (requestCancellation.get(request.id())) {
requestCancellation.remove(request.id());
return Mono.empty();
}
else {
return Mono.just(result);
}
}).doOnCancel(() -> requestCancellation.remove(request.id())));

}
return resultMono
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
.onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
error.getMessage(), null)))); // TODO: add error message
// through the data field
.onErrorResume(error -> {
if (requestCancellation.get(request.id())) {
requestCancellation.remove(request.id());
return Mono.empty();
}
else {
return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
error.getMessage(), null)));
}
}); // TODO: add error message
// through the data field
});
}

Expand All @@ -240,6 +273,17 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get()));
return this.initNotificationHandler.handle();
}
else if (McpSchema.METHOD_NOTIFICATION_CANCELLED.equals(notification.method())) {
McpSchema.CancellationMessageNotification cancellationMessageNotification = transport
.unmarshalFrom(notification.params(), new TypeReference<>() {
});
if (requestCancellation.containsKey(cancellationMessageNotification.requestId())) {
logger.warn("Received cancellation notification for request {}, cancellation reason is {}",
cancellationMessageNotification.requestId(), cancellationMessageNotification.reason());
requestCancellation.put(cancellationMessageNotification.requestId(), true);
}
return Mono.empty();
}

var handler = notificationHandlers.get(notification.method());
if (handler == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import reactor.core.publisher.Sinks;
import reactor.test.StepVerifier;

import static io.modelcontextprotocol.spec.McpSchema.METHOD_NOTIFICATION_CANCELLED;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

Expand Down Expand Up @@ -119,6 +120,25 @@ void testRequestTimeout() {
.verify(TIMEOUT.plusSeconds(1));
}

@Test
void testCancellationMessageNotificationForRequestTimeout() {
Mono<String> responseMono = session.sendRequest(TEST_METHOD, "test", responseType);

StepVerifier.create(responseMono)
.expectError(java.util.concurrent.TimeoutException.class)
.verify(TIMEOUT.plusSeconds(1));

McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage();
assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class);
McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage;
assertThat(notification.method()).isEqualTo(METHOD_NOTIFICATION_CANCELLED);
McpSchema.CancellationMessageNotification cancellationMessageNotification = transport
.unmarshalFrom(notification.params(), new TypeReference<>() {
});
assertThat(cancellationMessageNotification.reason()
.contains("The request times out, timeout: " + TIMEOUT.toMillis() + " ms")).isTrue();
}

@Test
void testSendNotification() {
Map<String, Object> params = Map.of("key", "value");
Expand Down