diff --git a/client/src/main/java/io/avaje/http/client/DHttpClientContext.java b/client/src/main/java/io/avaje/http/client/DHttpClientContext.java index a1f06d5..6e79ab4 100644 --- a/client/src/main/java/io/avaje/http/client/DHttpClientContext.java +++ b/client/src/main/java/io/avaje/http/client/DHttpClientContext.java @@ -11,7 +11,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; import java.util.concurrent.atomic.LongAccumulator; import java.util.concurrent.atomic.LongAdder; @@ -33,6 +38,8 @@ final class DHttpClientContext implements HttpClientContext { private final boolean withAuthToken; private final AuthTokenProvider authTokenProvider; private final AtomicReference tokenRef = new AtomicReference<>(); + private final Executor executor; + private final AtomicLong activeAsync = new AtomicLong(); private int loggingMaxBody = 1_000; private final LongAdder metricResTotal = new LongAdder(); @@ -41,7 +48,7 @@ final class DHttpClientContext implements HttpClientContext { private final LongAdder metricResMicros = new LongAdder(); private final LongAccumulator metricResMaxMicros = new LongAccumulator(Math::max, 0); - DHttpClientContext(HttpClient httpClient, String baseUrl, Duration requestTimeout, BodyAdapter bodyAdapter, RetryHandler retryHandler, RequestListener requestListener, AuthTokenProvider authTokenProvider, RequestIntercept intercept) { + DHttpClientContext(HttpClient httpClient, String baseUrl, Duration requestTimeout, BodyAdapter bodyAdapter, RetryHandler retryHandler, RequestListener requestListener, AuthTokenProvider authTokenProvider, RequestIntercept intercept, Executor executor) { this.httpClient = httpClient; this.baseUrl = baseUrl; this.requestTimeout = requestTimeout; @@ -51,6 +58,7 @@ final class DHttpClientContext implements HttpClientContext { this.authTokenProvider = authTokenProvider; this.withAuthToken = authTokenProvider != null; this.requestIntercept = intercept; + this.executor = executor; } @Override @@ -255,7 +263,13 @@ HttpResponse send(HttpRequest.Builder requestBuilder, HttpResponse.BodyHa } CompletableFuture> sendAsync(HttpRequest.Builder requestBuilder, HttpResponse.BodyHandler bodyHandler) { - return httpClient.sendAsync(requestBuilder.build(), bodyHandler); + activeAsync.incrementAndGet(); + if (executor == null) { + // defaults to ForkJoinPool.commonPool() + return httpClient.sendAsync(requestBuilder.build(), bodyHandler); + } else { + return httpClient.sendAsync(requestBuilder.build(), bodyHandler).whenCompleteAsync((r, t)-> {}, executor); + } } BodyContent write(T bean, String contentType) { @@ -274,6 +288,37 @@ List readList(Class cls, BodyContent content) { return bodyAdapter.listReader(cls).read(content); } + @Override + public boolean shutdown(long timeout, TimeUnit timeUnit) { + long timeoutMillis = TimeUnit.MILLISECONDS.convert(timeout, timeUnit); + if (!waitForActiveAsync(timeoutMillis)) { + return false; + } + if (executor instanceof ExecutorService) { + ExecutorService es = (ExecutorService)executor; + es.shutdown(); + try { + return es.awaitTermination(timeout, timeUnit); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + System.getLogger("io.avaje.http.client").log(System.Logger.Level.WARNING, "Interrupt on shutdown", e); + return false; + } + } + return true; + } + + private boolean waitForActiveAsync(long millis) { + final long until = System.currentTimeMillis() + millis; + do { + if (activeAsync.get() <= 0) { + return true; + } + LockSupport.parkNanos(10_000_000); + } while (System.currentTimeMillis() < until); + return false; + } + void afterResponse(DHttpClientRequest request) { metricResTotal.add(1); metricResMicros.add(request.responseTimeMicros()); @@ -287,6 +332,9 @@ void afterResponse(DHttpClientRequest request) { if (requestIntercept != null) { requestIntercept.afterResponse(request.response(), request); } + if (request.startAsyncNanos > 0) { + activeAsync.decrementAndGet(); + } } void beforeRequest(DHttpClientRequest request) { diff --git a/client/src/main/java/io/avaje/http/client/DHttpClientContextBuilder.java b/client/src/main/java/io/avaje/http/client/DHttpClientContextBuilder.java index f9e390e..fdbb2aa 100644 --- a/client/src/main/java/io/avaje/http/client/DHttpClientContextBuilder.java +++ b/client/src/main/java/io/avaje/http/client/DHttpClientContextBuilder.java @@ -159,7 +159,7 @@ public HttpClientContext build() { // register the built in request/response logging requestListener(new RequestLogger()); } - return new DHttpClientContext(client, baseUrl, requestTimeout, bodyAdapter, retryHandler, buildListener(), authTokenProvider, buildIntercept()); + return new DHttpClientContext(client, baseUrl, requestTimeout, bodyAdapter, retryHandler, buildListener(), authTokenProvider, buildIntercept(), executor); } private RequestListener buildListener() { diff --git a/client/src/main/java/io/avaje/http/client/DHttpClientRequest.java b/client/src/main/java/io/avaje/http/client/DHttpClientRequest.java index 3c080ab..0104b3e 100644 --- a/client/src/main/java/io/avaje/http/client/DHttpClientRequest.java +++ b/client/src/main/java/io/avaje/http/client/DHttpClientRequest.java @@ -50,7 +50,7 @@ class DHttpClientRequest implements HttpClientRequest, HttpClientResponse { private boolean loggableResponseBody; private boolean skipAuthToken; private boolean suppressLogging; - private long startAsyncNanos; + protected long startAsyncNanos; private String label; private Map customAttributes; diff --git a/client/src/main/java/io/avaje/http/client/HttpClientContext.java b/client/src/main/java/io/avaje/http/client/HttpClientContext.java index d65bf01..556244e 100644 --- a/client/src/main/java/io/avaje/http/client/HttpClientContext.java +++ b/client/src/main/java/io/avaje/http/client/HttpClientContext.java @@ -9,6 +9,7 @@ import java.net.http.HttpResponse; import java.time.Duration; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; /** * The HTTP client context that we use to build and process requests. @@ -129,6 +130,19 @@ static HttpClientContext.Builder newBuilder() { */ byte[] decodeContent(String encoding, byte[] content); + /** + * When this context is created with an Executor and that is an ExecutorService + * then this will wait for async requests to be processed and then shutdown the + * ExecutorService. + * + * @param timeout The maximum time to wait for async processes to complete + * @param timeUnit The time unit for maximum wait time + * @return True when successfully waited for async requests and shutdown + * + * @see HttpClientContext.Builder#executor(Executor) + */ + boolean shutdown(long timeout, TimeUnit timeUnit); + /** * Builds the HttpClientContext. * diff --git a/client/src/test/java/io/avaje/http/client/AsyncExecutorTest.java b/client/src/test/java/io/avaje/http/client/AsyncExecutorTest.java new file mode 100644 index 0000000..e591148 --- /dev/null +++ b/client/src/test/java/io/avaje/http/client/AsyncExecutorTest.java @@ -0,0 +1,59 @@ +package io.avaje.http.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.http.HttpResponse; +import java.util.List; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; + +class AsyncExecutorTest extends BaseWebTest { + + final Logger log = LoggerFactory.getLogger(AsyncExecutorTest.class); + + @Test + void context_shutdown() { + + final HttpClientContext clientContext = HttpClientContext.newBuilder() + .baseUrl(baseUrl) + .bodyAdapter(new JacksonBodyAdapter(new ObjectMapper())) + .executor(Executors.newSingleThreadExecutor()) + .build(); + + final CompletableFuture>> future = clientContext.request() + .path("hello").path("stream") + .GET() + .async() + .asLines(); + + final AtomicReference threadName = new AtomicReference<>(); + final AtomicBoolean flag = new AtomicBoolean(); + future.whenComplete((hres, throwable) -> { + flag.set(true); + threadName.set(Thread.currentThread().getName()); + log.info("processing response"); + LockSupport.parkNanos(600_000_000); + assertThat(hres.statusCode()).isEqualTo(200); + List lines = hres.body().collect(Collectors.toList()); + assertThat(lines).hasSize(4); + assertThat(lines.get(0)).contains("{\"id\":1, \"name\":\"one\"}"); + log.info("processing response complete"); + }); + + assertThat(flag).isFalse(); // haven't run the async process yet + assertThat(clientContext.shutdown(2, TimeUnit.SECONDS)).isTrue(); + assertThat(flag).isTrue(); + assertThat(threadName.get()).endsWith("-thread-1"); + } + +} + diff --git a/client/src/test/java/io/avaje/http/client/DHttpClientContextTest.java b/client/src/test/java/io/avaje/http/client/DHttpClientContextTest.java index abd57b9..5d28c84 100644 --- a/client/src/test/java/io/avaje/http/client/DHttpClientContextTest.java +++ b/client/src/test/java/io/avaje/http/client/DHttpClientContextTest.java @@ -10,7 +10,7 @@ class DHttpClientContextTest { - private final DHttpClientContext context = new DHttpClientContext(null, null, null, null, null, null, null, null); + private final DHttpClientContext context = new DHttpClientContext(null, null, null, null, null, null, null, null, null); @Test void create() { diff --git a/client/src/test/java/io/avaje/http/client/DHttpClientRequestTest.java b/client/src/test/java/io/avaje/http/client/DHttpClientRequestTest.java index f771785..5d2fdbb 100644 --- a/client/src/test/java/io/avaje/http/client/DHttpClientRequestTest.java +++ b/client/src/test/java/io/avaje/http/client/DHttpClientRequestTest.java @@ -8,7 +8,7 @@ class DHttpClientRequestTest { - final DHttpClientContext context = new DHttpClientContext(null, null, null, null, null, null, null, null); + final DHttpClientContext context = new DHttpClientContext(null, null, null, null, null, null, null, null, null); @Test void suppressLogging_listenerEvent_expect_suppressedPayloadContent() {