diff --git a/spring-graphql/src/main/java/org/springframework/graphql/web/webflux/GraphQlWebSocketHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/web/webflux/GraphQlWebSocketHandler.java index c983e5a39..0b0798dda 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/web/webflux/GraphQlWebSocketHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/web/webflux/GraphQlWebSocketHandler.java @@ -246,7 +246,10 @@ private Flux handleWebOutput(WebSocketSession session, String .message(ex.getMessage()) .build() .toSpecification(); - return Mono.just(encode(session, id, MessageType.ERROR, errorMap)); + + // Payload needs to be an array + // see: https://github.com/enisdenjo/graphql-ws/blob/master/docs/interfaces/common.ErrorMessage.md#payload + return Mono.just(encode(session, id, MessageType.ERROR, Collections.singletonList(errorMap))); }); } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/web/webmvc/GraphQlWebSocketHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/web/webmvc/GraphQlWebSocketHandler.java index caaafe19d..ca7254a0e 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/web/webmvc/GraphQlWebSocketHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/web/webmvc/GraphQlWebSocketHandler.java @@ -260,7 +260,10 @@ private Flux handleWebOutput(WebSocketSession session, String id, W String message = ex.getMessage(); Map errorMap = GraphqlErrorBuilder.newError().errorType(errorType).message(message).build() .toSpecification(); - return Mono.just(encode(id, MessageType.ERROR, errorMap)); + + // Payload needs to be an array + // see: https://github.com/enisdenjo/graphql-ws/blob/master/docs/interfaces/common.ErrorMessage.md#payload + return Mono.just(encode(id, MessageType.ERROR, Collections.singletonList(errorMap))); }); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/web/webflux/GraphQlWebSocketHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/web/webflux/GraphQlWebSocketHandlerTests.java index c6a6340bf..5bfc3e939 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/web/webflux/GraphQlWebSocketHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/web/webflux/GraphQlWebSocketHandlerTests.java @@ -26,6 +26,8 @@ import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; +import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.web.WebGraphQlHandler; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; @@ -245,6 +247,60 @@ void clientCompletion() { .verifyTimeout(Duration.ofMillis(500)); } + @Test + void errorMessagePayloadIsCorrectArray() { + final String GREETING_QUERY = "{" + + "\"id\":\"" + SUBSCRIPTION_ID + "\"," + + "\"type\":\"subscribe\"," + + "\"payload\":{\"query\": \"" + + " subscription TestTypenameSubscription {" + + " greeting" + + " }\"}" + + "}"; + + WebGraphQlHandler initHandler = GraphQlSetup.schemaContent("" + + "type Subscription { greeting: String! }" + + "type Query { greetingUnused: String! }") + .subscriptionFetcher("greeting", env -> Flux.just("a", null, "b")) + .webInterceptor() + .toWebGraphQlHandler(); + + GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler( + initHandler, + ServerCodecConfigurer.create(), + Duration.ofSeconds(60)); + + TestWebSocketSession session = new TestWebSocketSession(Flux.just( + toWebSocketMessage("{\"type\":\"connection_init\"}"), + toWebSocketMessage(GREETING_QUERY))); + handler.handle(session).block(); + + StepVerifier.create(session.getOutput()) + .consumeNextWith((message) -> assertMessageType(message, "connection_ack")) + .consumeNextWith((message) -> assertThat(decode(message)) + .hasSize(3) + .containsEntry("id", SUBSCRIPTION_ID) + .containsEntry("type", "next") + .extractingByKey("payload", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .extractingByKey("data", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .containsEntry("greeting", "a")) + .consumeNextWith((message) -> assertThat(decode(message)) + .hasSize(3) + .containsEntry("id", SUBSCRIPTION_ID) + .containsEntry("type", "error") + .hasEntrySatisfying("payload", payload -> assertThat(payload) + .asList() + .hasSize(1) + .allSatisfy(theError -> assertThat(theError) + .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) + .hasSize(3) + .hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty()) + .hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("null")) + .extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .containsEntry("classification", "DataFetchingException")))) + .verifyComplete(); + } + private TestWebSocketSession handle(Flux input, WebInterceptor... interceptors) { GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler( initHandler(interceptors), diff --git a/spring-graphql/src/test/java/org/springframework/graphql/web/webmvc/GraphQlWebSocketHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/web/webmvc/GraphQlWebSocketHandlerTests.java index 52d1701e0..fbdc57f13 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/web/webmvc/GraphQlWebSocketHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/web/webmvc/GraphQlWebSocketHandlerTests.java @@ -29,6 +29,9 @@ import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; +import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.web.WebGraphQlHandler; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -249,6 +252,57 @@ void clientCompletion() throws Exception { .verifyTimeout(Duration.ofMillis(500)); } + @Test + void errorMessagePayloadIsCorrectArray() throws Exception { + final String GREETING_QUERY = "{" + + "\"id\":\"" + SUBSCRIPTION_ID + "\"," + + "\"type\":\"subscribe\"," + + "\"payload\":{\"query\": \"" + + " subscription TestTypenameSubscription {" + + " greeting" + + " }\"}" + + "}"; + + WebGraphQlHandler initHandler = GraphQlSetup.schemaContent("" + + "type Subscription { greeting: String! }" + + "type Query { greetingUnused: String! }") + .subscriptionFetcher("greeting", env -> Flux.just("a", null, "b")) + .webInterceptor() + .toWebGraphQlHandler(); + + GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(initHandler, converter, Duration.ofSeconds(60)); + + handle(handler, + new TextMessage("{\"type\":\"connection_init\"}"), + new TextMessage(GREETING_QUERY)); + + StepVerifier.create(this.session.getOutput()) + .consumeNextWith((message) -> assertMessageType(message, "connection_ack")) + .consumeNextWith((message) -> assertThat(decode(message)) + .hasSize(3) + .containsEntry("id", SUBSCRIPTION_ID) + .containsEntry("type", "next") + .extractingByKey("payload", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .extractingByKey("data", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .containsEntry("greeting", "a")) + .consumeNextWith((message) -> assertThat(decode(message)) + .hasSize(3) + .containsEntry("id", SUBSCRIPTION_ID) + .containsEntry("type", "error") + .hasEntrySatisfying("payload", payload -> assertThat(payload) + .asList() + .hasSize(1) + .allSatisfy(theError -> assertThat(theError) + .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) + .hasSize(3) + .hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty()) + .hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("null")) + .extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .containsEntry("classification", "DataFetchingException")))) + .then(this.session::close) + .verifyComplete(); + } + private void handle(GraphQlWebSocketHandler handler, TextMessage... textMessages) throws Exception { handler.afterConnectionEstablished(this.session); for (TextMessage message : textMessages) {