Skip to content

Make HttpRequestHandlerImpl.handle() async when possible (fixes #259) #268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 19, 2020
7 changes: 7 additions & 0 deletions .github/workflows/pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,34 @@ jobs:
name: Sonar analysis
needs: validation
runs-on: ubuntu-latest
env:
SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}
steps:
- uses: actions/checkout@v2
if: env.SONAR_TOKEN != null
with:
fetch-depth: 0 # Shallow clones should be disabled for a better relevancy of analysis
- name: Set up JDK 11
if: env.SONAR_TOKEN != null
uses: actions/setup-java@v1
with:
java-version: 11
- name: Cache SonarCloud packages
if: env.SONAR_TOKEN != null
uses: actions/cache@v1
with:
path: ~/.sonar/cache
key: ${{ runner.os }}-sonar
restore-keys: ${{ runner.os }}-sonar
- name: Cache Gradle packages
if: env.SONAR_TOKEN != null
uses: actions/cache@v1
with:
path: ~/.gradle/caches
key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }}
restore-keys: ${{ runner.os }}-gradle
- name: Build and analyze
if: env.SONAR_TOKEN != null
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Needed to get PR information, if any
SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import graphql.kickstart.execution.input.GraphQLBatchedInvocationInput;
import graphql.kickstart.execution.input.GraphQLInvocationInput;
import graphql.kickstart.execution.input.GraphQLSingleInvocationInput;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import lombok.AllArgsConstructor;
Expand All @@ -28,26 +29,36 @@ public CompletableFuture<ExecutionResult> executeAsync(
}

public GraphQLQueryResult query(GraphQLInvocationInput invocationInput) {
return queryAsync(invocationInput).join();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I preserved the original signatures for binary compatibility.

This one is also still used by AbstractGraphQLHttpServlet.executeQuery() for JMX.

}

public CompletableFuture<GraphQLQueryResult> queryAsync(GraphQLInvocationInput invocationInput) {
if (invocationInput instanceof GraphQLSingleInvocationInput) {
return GraphQLQueryResult.create(query((GraphQLSingleInvocationInput) invocationInput));
return executeAsync((GraphQLSingleInvocationInput)invocationInput).thenApply(GraphQLQueryResult::create);
}
GraphQLBatchedInvocationInput batchedInvocationInput = (GraphQLBatchedInvocationInput) invocationInput;
return GraphQLQueryResult.create(query(batchedInvocationInput));
return executeAsync(batchedInvocationInput).thenApply(GraphQLQueryResult::create);
}

private ExecutionResult query(GraphQLSingleInvocationInput singleInvocationInput) {
return executeAsync(singleInvocationInput).join();
private CompletableFuture<List<ExecutionResult>> executeAsync(GraphQLBatchedInvocationInput batchedInvocationInput) {
GraphQL graphQL = batchedDataLoaderGraphQLBuilder.newGraphQL(batchedInvocationInput, graphQLBuilder);
return sequence(
batchedInvocationInput.getExecutionInputs().stream()
.map(executionInput -> proxy.executeAsync(graphQL, executionInput))
.collect(toList()));
}

private List<ExecutionResult> query(GraphQLBatchedInvocationInput batchedInvocationInput) {
GraphQL graphQL = batchedDataLoaderGraphQLBuilder
.newGraphQL(batchedInvocationInput, graphQLBuilder);
return batchedInvocationInput.getExecutionInputs().stream()
.map(executionInput -> proxy.executeAsync(graphQL, executionInput))
.collect(toList())
.stream()
.map(CompletableFuture::join)
.collect(toList());
@SuppressWarnings({"unchecked", "rawtypes"})
private <T> CompletableFuture<List<T>> sequence(List<CompletableFuture<T>> futures) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately the JDK doesn't provide this common operation out of the box 😞
It's also a bit hacky to work around the vararg argument of allOf().

CompletableFuture[] futuresArray = futures.toArray(new CompletableFuture[0]);
return CompletableFuture.allOf(futuresArray).thenApply(aVoid -> {
List<T> result = new ArrayList<>(futures.size());
for (CompletableFuture future : futuresArray) {
assert future.isDone(); // per the API contract of allOf()
result.add((T) future.join());
}
return result;
});
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.servlet.AsyncContext;
import javax.servlet.Servlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -135,59 +134,38 @@ public String executeQuery(String query) {
}
}

private void doRequestAsync(HttpServletRequest request, HttpServletResponse response,
HttpRequestHandler handler) {
if (configuration.isAsyncServletModeEnabled()) {
AsyncContext asyncContext = request.startAsync(request, response);
asyncContext.setTimeout(configuration.getAsyncTimeout());
HttpServletRequest asyncRequest = (HttpServletRequest) asyncContext.getRequest();
HttpServletResponse asyncResponse = (HttpServletResponse) asyncContext.getResponse();
configuration.getAsyncExecutor()
.execute(() -> doRequest(asyncRequest, asyncResponse, handler, asyncContext));
} else {
doRequest(request, response, handler, null);
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
doRequest(req, resp);
}

private void doRequest(HttpServletRequest request, HttpServletResponse response,
HttpRequestHandler handler,
AsyncContext asyncContext) {
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
doRequest(req, resp);
}

private void doRequest(HttpServletRequest request, HttpServletResponse response) {
init();
List<GraphQLServletListener.RequestCallback> requestCallbacks = runListeners(
l -> l.onRequest(request, response));

try {
handler.handle(request, response);
requestHandler.handle(request, response);
runCallbacks(requestCallbacks, c -> c.onSuccess(request, response));
} catch (Throwable t) {
} catch (Exception t) {
log.error("Error executing GraphQL request!", t);
runCallbacks(requestCallbacks, c -> c.onError(request, response, t));
} finally {
runCallbacks(requestCallbacks, c -> c.onFinally(request, response));
if (asyncContext != null) {
asyncContext.complete();
}
}
}

@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
init();
doRequestAsync(req, resp, requestHandler);
}

@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
init();
doRequestAsync(req, resp, requestHandler);
}

private <R> List<R> runListeners(Function<? super GraphQLServletListener, R> action) {
return configuration.getListeners().stream()
.map(listener -> {
try {
return action.apply(listener);
} catch (Throwable t) {
} catch (Exception t) {
log.error("Error running listener: {}", listener, t);
return null;
}
Expand All @@ -200,7 +178,7 @@ private <T> void runCallbacks(List<T> callbacks, Consumer<T> action) {
callbacks.forEach(callback -> {
try {
action.accept(callback);
} catch (Throwable t) {
} catch (Exception t) {
log.error("Error running callback: {}", callback, t);
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,19 @@ public class GraphQLConfiguration {
private final GraphQLInvoker graphQLInvoker;
private final GraphQLObjectMapper objectMapper;
private final List<GraphQLServletListener> listeners;
/**
* For removal
* @since 10.1.0
*/
@Deprecated
private final boolean asyncServletModeEnabled;
/**
* For removal
* @since 10.1.0
*/
@Deprecated
private final Executor asyncExecutor;

private final long subscriptionTimeout;
@Getter
private final long asyncTimeout;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package graphql.kickstart.servlet;

import com.fasterxml.jackson.core.JsonProcessingException;
import graphql.GraphQLException;
import graphql.kickstart.execution.GraphQLInvoker;
import graphql.kickstart.execution.GraphQLQueryResult;
Expand All @@ -9,6 +10,9 @@
import graphql.kickstart.servlet.input.BatchInputPreProcessResult;
import graphql.kickstart.servlet.input.BatchInputPreProcessor;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.concurrent.CompletableFuture;
import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -36,11 +40,11 @@ public void handle(HttpServletRequest request, HttpServletResponse response) thr
GraphQLInvocationInput invocationInput = invocationInputParser
.getGraphQLInvocationInput(request, response);
execute(invocationInput, request, response);
} catch (GraphQLException e) {
} catch (GraphQLException| JsonProcessingException e) {
response.setStatus(STATUS_BAD_REQUEST);
log.info("Bad request: cannot handle http request", e);
throw e;
} catch (Throwable t) {
} catch (Exception t) {
response.setStatus(500);
log.error("Cannot handle http request", t);
throw t;
Expand All @@ -49,38 +53,65 @@ public void handle(HttpServletRequest request, HttpServletResponse response) thr

protected void execute(GraphQLInvocationInput invocationInput, HttpServletRequest request,
HttpServletResponse response) throws IOException {
GraphQLQueryResult queryResult = invoke(invocationInput, request, response);
if (request.isAsyncSupported()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated comment. as you are touching this already.
Do you mind changing line 41 to GraphQLException | IOException e? invocationInputParser.getGraphQLInvocationInput throws IOException (JsonProcessingException or JsonMappingException specifically) in case of unparseable input request. That results in #258

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are those the only IOException instances that can happen here? Just making sure that we're not also catching others that would qualify as legitimate server errors.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Official method signature throws IOException. So it's hard to guarantee that. The use case right now is that servlet returns http 500 for bad requests that can't be deserialized. And that makes external attacks last longer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added JsonProcessingException instead of IOException.

AsyncContext asyncContext = request.isAsyncStarted()
? request.getAsyncContext()
: request.startAsync(request, response);
asyncContext.setTimeout(configuration.getAsyncTimeout());
invoke(invocationInput, request, response)
.thenAccept(result -> writeResultResponse(invocationInput, result, request, response))
.exceptionally(t -> writeErrorResponse(t, response))
.thenAccept(aVoid -> asyncContext.complete());
} else {
try {
GraphQLQueryResult result = invoke(invocationInput, request, response).join();
writeResultResponse(invocationInput, result, request, response);
} catch (Exception t) {
writeErrorResponse(t, response);
}
}
}

private void writeResultResponse(GraphQLInvocationInput invocationInput, GraphQLQueryResult queryResult, HttpServletRequest request,
HttpServletResponse response) {
QueryResponseWriter queryResponseWriter = createWriter(invocationInput, queryResult);
queryResponseWriter.write(request, response);
try {
queryResponseWriter.write(request, response);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

protected QueryResponseWriter createWriter(GraphQLInvocationInput invocationInput,
GraphQLQueryResult queryResult) {
protected QueryResponseWriter createWriter(GraphQLInvocationInput invocationInput, GraphQLQueryResult queryResult) {
return QueryResponseWriter.createWriter(queryResult, configuration.getObjectMapper(),
configuration.getSubscriptionTimeout());
}

private GraphQLQueryResult invoke(GraphQLInvocationInput invocationInput,
HttpServletRequest request,
private Void writeErrorResponse(Throwable t, HttpServletResponse response) {
response.setStatus(STATUS_BAD_REQUEST);
log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given", t);
return null;
}

private CompletableFuture<GraphQLQueryResult> invoke(GraphQLInvocationInput invocationInput, HttpServletRequest request,
HttpServletResponse response) {
if (invocationInput instanceof GraphQLSingleInvocationInput) {
return graphQLInvoker.query(invocationInput);
return graphQLInvoker.queryAsync(invocationInput);
}
return invokeBatched((GraphQLBatchedInvocationInput) invocationInput, request, response);
}

private GraphQLQueryResult invokeBatched(GraphQLBatchedInvocationInput batchedInvocationInput,
private CompletableFuture<GraphQLQueryResult> invokeBatched(GraphQLBatchedInvocationInput batchedInvocationInput,
HttpServletRequest request,
HttpServletResponse response) {
BatchInputPreProcessor preprocessor = configuration.getBatchInputPreProcessor();
BatchInputPreProcessResult result = preprocessor
.preProcessBatch(batchedInvocationInput, request, response);
if (result.isExecutable()) {
return graphQLInvoker.query(result.getBatchedInvocationInput());
return graphQLInvoker.queryAsync(result.getBatchedInvocationInput());
}

return GraphQLQueryResult.createError(result.getStatusCode(), result.getStatusMessage());
return CompletableFuture.completedFuture(GraphQLQueryResult.createError(result.getStatusCode(), result.getStatusMessage()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ protected void execute(GraphQLInvocationInput invocationInput, HttpServletReques
}
}

@Override
protected QueryResponseWriter createWriter(GraphQLInvocationInput invocationInput,
GraphQLQueryResult queryResult) {
return CachingQueryResponseWriter.createCacheWriter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class AbstractGraphQLHttpServletSpec extends Specification {
getResponseContent().data.echo == "special char á"
}

def "async query over HTTP GET starts async request"() {
def "disabling async support on request over HTTP GET does not start async request"() {
setup:
servlet = TestUtils.createDefaultServlet({ env -> env.arguments.arg }, { env -> env.arguments.arg }, { env ->
AtomicReference<SingleSubscriberPublisher<String>> publisherRef = new AtomicReference<>();
Expand All @@ -138,12 +138,13 @@ class AbstractGraphQLHttpServletSpec extends Specification {
return publisherRef.get()
}, true)
request.addParameter('query', 'query { echo(arg:"test") }')
request.setAsyncSupported(false)

when:
servlet.doGet(request, response)

then:
request.asyncStarted == true
request.asyncContext == null
}

def "query over HTTP GET with variables returns data"() {
Expand Down Expand Up @@ -442,7 +443,7 @@ class AbstractGraphQLHttpServletSpec extends Specification {
getResponseContent().data.echo == "test"
}

def "async query over HTTP POST starts async request"() {
def "disabling async support on request over HTTP POST does not start async request"() {
setup:
servlet = TestUtils.createDefaultServlet({ env -> env.arguments.arg }, { env -> env.arguments.arg }, { env ->
AtomicReference<SingleSubscriberPublisher<String>> publisherRef = new AtomicReference<>();
Expand All @@ -455,12 +456,13 @@ class AbstractGraphQLHttpServletSpec extends Specification {
request.setContent(mapper.writeValueAsBytes([
query: 'query { echo(arg:"test") }'
]))
request.setAsyncSupported(false)

when:
servlet.doPost(request, response)

then:
request.asyncStarted == true
request.asyncContext == null
}

def "query over HTTP POST body with graphql contentType returns data"() {
Expand Down
Loading