diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/SubscriptionProtocolFactory.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/SubscriptionProtocolFactory.java index ed2cddb3..3c1736a4 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/SubscriptionProtocolFactory.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/SubscriptionProtocolFactory.java @@ -16,4 +16,8 @@ public String getProtocol() { } public abstract Consumer createConsumer(SubscriptionSession session); + + public void shutdown() { + // do nothing + } } diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionConnectionListener.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionConnectionListener.java index 2745c4b3..95d52f5e 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionConnectionListener.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionConnectionListener.java @@ -20,4 +20,8 @@ default void onStop(SubscriptionSession session, OperationMessage message) { default void onTerminate(SubscriptionSession session, OperationMessage message) { // do nothing } -} + + default void shutdown() { + // do nothing + } +} \ No newline at end of file diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionKeepAliveRunner.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionKeepAliveRunner.java index 4ad8ee21..91949a51 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionKeepAliveRunner.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionKeepAliveRunner.java @@ -61,4 +61,8 @@ void abort(SubscriptionSession session) { future.cancel(true); } } + + void shutdown() { + this.executor.shutdown(); + } } diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionProtocolFactory.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionProtocolFactory.java index 0737f628..ec76b384 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionProtocolFactory.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/ApolloSubscriptionProtocolFactory.java @@ -19,6 +19,7 @@ public class ApolloSubscriptionProtocolFactory extends SubscriptionProtocolFacto public static final int KEEP_ALIVE_INTERVAL = 15; @Getter private final GraphQLObjectMapper objectMapper; private final ApolloCommandProvider commandProvider; + private KeepAliveSubscriptionConnectionListener keepAlive; public ApolloSubscriptionProtocolFactory( GraphQLObjectMapper objectMapper, @@ -67,7 +68,8 @@ public ApolloSubscriptionProtocolFactory( if (keepAliveInterval != null && listeners.stream() .noneMatch(KeepAliveSubscriptionConnectionListener.class::isInstance)) { - listeners.add(new KeepAliveSubscriptionConnectionListener(keepAliveInterval)); + keepAlive = new KeepAliveSubscriptionConnectionListener(keepAliveInterval); + listeners.add(keepAlive); } commandProvider = new ApolloCommandProvider( @@ -81,4 +83,11 @@ public ApolloSubscriptionProtocolFactory( public Consumer createConsumer(SubscriptionSession session) { return new ApolloSubscriptionConsumer(session, objectMapper, commandProvider); } + + @Override + public void shutdown() { + if (keepAlive != null) { + keepAlive.shutdown(); + } + } } diff --git a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/KeepAliveSubscriptionConnectionListener.java b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/KeepAliveSubscriptionConnectionListener.java index 4d347410..93c21984 100644 --- a/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/KeepAliveSubscriptionConnectionListener.java +++ b/graphql-java-kickstart/src/main/java/graphql/kickstart/execution/subscriptions/apollo/KeepAliveSubscriptionConnectionListener.java @@ -6,7 +6,7 @@ public class KeepAliveSubscriptionConnectionListener implements ApolloSubscriptionConnectionListener { - private final ApolloSubscriptionKeepAliveRunner keepAliveRunner; + protected final ApolloSubscriptionKeepAliveRunner keepAliveRunner; public KeepAliveSubscriptionConnectionListener() { this(Duration.ofSeconds(15)); @@ -35,4 +35,10 @@ public void onStop(SubscriptionSession session, OperationMessage message) { public void onTerminate(SubscriptionSession session, OperationMessage message) { keepAliveRunner.abort(session); } + + @Override + public void shutdown() { + keepAliveRunner.shutdown(); + } + } diff --git a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java index 4302cc00..7ff86408 100644 --- a/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java +++ b/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/GraphQLWebsocketServlet.java @@ -116,6 +116,24 @@ public GraphQLWebsocketServlet( .collect(toList()); } + public GraphQLWebsocketServlet( + GraphQLInvoker graphQLInvoker, + GraphQLSubscriptionInvocationInputFactory invocationInputFactory, + GraphQLObjectMapper graphQLObjectMapper, + List subscriptionProtocolFactory, + SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory) { + + this.subscriptionProtocolFactories = subscriptionProtocolFactory; + this.fallbackSubscriptionProtocolFactory = fallbackSubscriptionProtocolFactory; + + allSubscriptionProtocols = + Stream.concat( + subscriptionProtocolFactories.stream(), + Stream.of(fallbackSubscriptionProtocolFactory)) + .map(SubscriptionProtocolFactory::getProtocol) + .collect(toList()); + } + @Override public void onOpen(Session session, EndpointConfig endpointConfig) { final WebSocketSubscriptionProtocolFactory subscriptionProtocolFactory = @@ -234,6 +252,12 @@ public void beginShutDown() { log.error("GraphQLWebsocketServlet did not shut down cleanly!"); sessionSubscriptionCache.clear(); } + + for (SubscriptionProtocolFactory protocolFactory : subscriptionProtocolFactories) { + protocolFactory.shutdown(); + } + + fallbackSubscriptionProtocolFactory.shutdown(); } isShutDown.set(true);