Skip to content

Commit f97a2a7

Browse files
committed
Fix race condition in RetryInterceptor
Signed-off-by: andreadimaio <[email protected]>
1 parent d2e9995 commit f97a2a7

File tree

2 files changed

+67
-148
lines changed

2 files changed

+67
-148
lines changed

modules/watsonx-ai-core/src/main/java/com/ibm/watsonx/ai/core/http/interceptors/RetryInterceptor.java

Lines changed: 51 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55
package com.ibm.watsonx.ai.core.http.interceptors;
66

7+
import static com.ibm.watsonx.ai.core.http.BaseHttpClient.REQUEST_ID_HEADER;
78
import static java.util.Objects.isNull;
89
import static java.util.Objects.requireNonNull;
910
import static java.util.Objects.requireNonNullElse;
@@ -34,11 +35,15 @@ public final class RetryInterceptor implements SyncHttpInterceptor, AsyncHttpInt
3435

3536
private static final Logger logger = LoggerFactory.getLogger(RetryInterceptor.class);
3637

38+
private record RetryOn(Class<? extends Throwable> clazz, Optional<Predicate<Throwable>> predicate) {}
39+
40+
private final Duration retryInterval;
41+
private final List<RetryOn> retryOn;
42+
private final boolean exponentialBackoff;
43+
private final Integer maxRetries;
44+
3745
/**
3846
* Checks whether a {@link WatsonxException} is retryable due to an expired authentication token.
39-
* <p>
40-
* This condition is met if the HTTP status code is 401 and at least one error in the exception's details has the code
41-
* {@code AUTHENTICATION_TOKEN_EXPIRED}.
4247
*
4348
* @param maxRetries the maximum number of retry attempts
4449
* @return a configured {@link RetryInterceptor} instance that handles token expiration retries
@@ -59,28 +64,19 @@ public static RetryInterceptor onTokenExpired(int maxRetries) {
5964
).build();
6065
}
6166

62-
private record RetryOn(Class<? extends Throwable> clazz, Optional<Predicate<Throwable>> predicate) {}
63-
64-
private final Duration retryInterval;
65-
private final List<RetryOn> retryOn;
66-
private final boolean exponentialBackoff;
67-
private Integer maxRetries;
68-
private Duration timeout;
69-
7067
/**
7168
* Creates a new {@code RetryInterceptor} using the provided builder.
7269
*
7370
* @param builder the builder instance
7471
*/
7572
public RetryInterceptor(Builder builder) {
7673
requireNonNull(builder);
77-
this.retryInterval = requireNonNullElse(builder.retryInterval, Duration.ofMillis(0));
78-
this.timeout = this.retryInterval;
79-
this.maxRetries = requireNonNullElse(builder.maxRetries, 1);
80-
this.retryOn = builder.retryOn;
81-
this.exponentialBackoff = builder.exponentialBackoff;
82-
if (isNull(retryOn) || retryOn.isEmpty())
83-
throw new RuntimeException("At least one exception must be specified");
74+
retryInterval = requireNonNullElse(builder.retryInterval, Duration.ofMillis(0));
75+
maxRetries = requireNonNullElse(builder.maxRetries, 1);
76+
retryOn = requireNonNull(builder.retryOn, "At least one exception must be specified");
77+
exponentialBackoff = builder.exponentialBackoff;
78+
if (exponentialBackoff && !retryInterval.isPositive())
79+
throw new IllegalArgumentException("Retry interval must be positive when exponential backoff is enabled");
8480
}
8581

8682
@Override
@@ -89,15 +85,23 @@ public <T> HttpResponse<T> intercept(HttpRequest request, BodyHandler<T> bodyHan
8985

9086
Throwable exception = null;
9187

88+
String requestId = request.headers()
89+
.firstValue(REQUEST_ID_HEADER)
90+
.orElseThrow(); // This should never happen. The SyncHttpClient and AsyncHttpClient add this header if it is not present.
91+
92+
Duration timeout = Duration.from(retryInterval);
93+
9294
for (int attempt = 0; attempt <= maxRetries; attempt++) {
9395

9496
try {
9597

96-
if (attempt > 0)
97-
Thread.sleep(timeout.toMillis());
98+
if (attempt > 0 && timeout.isPositive()) {
99+
logger.debug("Retry request \"{}\" after {} ms", requestId, timeout.toMillis());
100+
Thread.sleep(timeout);
101+
}
98102

99103
var res = chain.proceed(request, bodyHandler);
100-
this.timeout = this.retryInterval;
104+
timeout = Duration.from(retryInterval);
101105
return res;
102106

103107
} catch (Exception e) {
@@ -117,54 +121,39 @@ public <T> HttpResponse<T> intercept(HttpRequest request, BodyHandler<T> bodyHan
117121
timeout = timeout.multipliedBy(2);
118122
}
119123
if (attempt > 0) {
120-
logger.debug("Retrying request ({}/{}) after failure: {}", attempt, maxRetries,
124+
logger.debug("Retrying request \"{}\" ({}/{}) after failure: {}", requestId, attempt, maxRetries,
121125
exception.getMessage());
122126
}
123127
chain.resetToIndex(index + 1);
124128
continue;
125129
}
126130

127-
this.timeout = this.retryInterval;
131+
timeout = Duration.from(retryInterval);
128132
throw e;
129133
}
130134
}
131135

132-
this.timeout = this.retryInterval;
133-
logger.debug("Max retries reached");
134-
135-
throw new RuntimeException("Max retries reached", isNull(exception) ? new Exception() : exception);
136+
timeout = Duration.from(retryInterval);
137+
throw new RuntimeException("Max retries reached for request [%s]".formatted(requestId), isNull(exception) ? new Exception() : exception);
136138
}
137139

138140
@Override
139141
public <T> CompletableFuture<HttpResponse<T>> intercept(HttpRequest request, BodyHandler<T> bodyHandler,
140142
Executor executor, int index, AsyncChain chain) {
141-
return executeWithRetry(request, bodyHandler, executor, index, 0, chain);
143+
return executeWithRetry(request, bodyHandler, executor, index, 0, Duration.from(retryInterval), chain);
142144
}
143145

144-
/**
145-
* The current timeout interval.
146-
*
147-
* @return the timeout duration
148-
*/
149-
public Duration getTimeout() {
150-
return timeout;
151-
}
152-
153-
/**
154-
* Returns a new {@link Builder} instance.
155-
*
156-
* @return {@link Builder} instance.
157-
*/
158-
public static Builder builder() {
159-
return new Builder();
160-
}
161146

162147
private <T> CompletableFuture<HttpResponse<T>> executeWithRetry(HttpRequest request, BodyHandler<T> bodyHandler,
163-
Executor executor, int index, int attempt, AsyncChain chain) {
148+
Executor executor, int index, int attempt, Duration timeout, AsyncChain chain) {
164149

165150
return chain.proceed(request, bodyHandler, executor)
166151
.exceptionallyCompose(throwable -> {
167152

153+
String requestId = request.headers()
154+
.firstValue(REQUEST_ID_HEADER)
155+
.orElseThrow(); // This should never happen. The SyncHttpClient and AsyncHttpClient add this header if it is not present.
156+
168157
Throwable cause = throwable.getCause() != null ? throwable.getCause() : throwable;
169158

170159
var shouldRetry =
@@ -178,27 +167,36 @@ private <T> CompletableFuture<HttpResponse<T>> executeWithRetry(HttpRequest requ
178167

179168
if (!shouldRetry || attempt >= maxRetries) {
180169
CompletableFuture<HttpResponse<T>> failed = new CompletableFuture<>();
181-
logger.debug("Max retries reached");
170+
logger.debug("Max retries reached for request \"{}\"", requestId);
182171
failed.completeExceptionally(cause);
183-
this.timeout = this.retryInterval;
184172
return failed;
185173
}
186174

187-
if (exponentialBackoff && attempt > 0)
188-
timeout = timeout.multipliedBy(2);
175+
Duration nextTimeout = exponentialBackoff ? timeout.multipliedBy(2) : timeout;
189176

190-
logger.debug("Retrying request ({}/{}) after failure: {}", attempt + 1, maxRetries, cause.getMessage());
177+
if (timeout.isPositive())
178+
logger.debug("Retry request \"{}\" after {} ms", requestId, nextTimeout.toMillis());
191179

192180
return CompletableFuture.supplyAsync(
193181
() -> {
182+
logger.debug("Retrying request \"{}\" ({}/{}) after failure: {}", requestId, attempt + 1, maxRetries, cause.getMessage());
194183
chain.resetToIndex(index + 1);
195-
return executeWithRetry(request, bodyHandler, executor, index, attempt + 1, chain);
184+
return executeWithRetry(request, bodyHandler, executor, index, attempt + 1, nextTimeout, chain);
196185
},
197-
CompletableFuture.delayedExecutor(timeout.toMillis(), TimeUnit.MILLISECONDS, executor)
186+
CompletableFuture.delayedExecutor(nextTimeout.toMillis(), TimeUnit.MILLISECONDS, executor)
198187
).thenCompose(Function.identity());
199188
});
200189
}
201190

191+
/**
192+
* Returns a new {@link Builder} instance.
193+
*
194+
* @return {@link Builder} instance.
195+
*/
196+
public static Builder builder() {
197+
return new Builder();
198+
}
199+
202200
/**
203201
* Builder for {@link RetryInterceptor}.
204202
*/

modules/watsonx-ai-core/src/test/java/com/ibm/watsonx/ai/core/RetryInterceptorTest.java

Lines changed: 16 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@ void setup() {
7878
when(httpRequest.headers()).thenReturn(HttpHeaders.of(Map.of(), (k, v) -> true));
7979
}
8080

81+
@Test
82+
void test_exponential_backoff_without_retry_interval() {
83+
84+
assertThrows(IllegalArgumentException.class, () -> RetryInterceptor.builder()
85+
.exponentialBackoff(true)
86+
.retryOn(NullPointerException.class)
87+
.build(), "Retry interval must be positive when exponential backoff is enabled");
88+
}
89+
8190
@Nested
8291
public class Sync {
8392

@@ -318,44 +327,7 @@ void retry_with_tool_exception() throws Exception {
318327
verify(mockInterceptor, times(2)).intercept(any(), eq(bodyHandler), anyInt(), any());
319328
}
320329

321-
@Test
322-
void retry_with_exponential_backoff_fail_retries() throws Exception {
323-
Duration timeout = Duration.ofMillis(10);
324-
RetryInterceptor retryInterceptor = RetryInterceptor.builder()
325-
.maxRetries(3)
326-
.retryInterval(timeout)
327-
.exponentialBackoff(true)
328-
.retryOn(NullPointerException.class)
329-
.build();
330-
331-
SyncHttpInterceptor mockInterceptor = new SyncHttpInterceptor() {
332-
int numCalled = 0;
333330

334-
@Override
335-
public <T> HttpResponse<T> intercept(HttpRequest request, BodyHandler<T> bodyHandler, int index, Chain chain)
336-
throws WatsonxException, IOException, InterruptedException {
337-
if (this.numCalled > 0)
338-
assertEquals((long) (timeout.toMillis() * Math.pow(2, this.numCalled - 1)),
339-
retryInterceptor.getTimeout().toMillis());
340-
this.numCalled++;
341-
return chain.proceed(request, bodyHandler);
342-
}
343-
};
344-
345-
SyncHttpClient client = SyncHttpClient.builder()
346-
.httpClient(httpClient)
347-
.interceptor(retryInterceptor)
348-
.interceptor(mockInterceptor)
349-
.build();
350-
351-
when(httpClient.send(any(), eq(bodyHandler)))
352-
.thenThrow(NullPointerException.class);
353-
354-
var ex = assertThrows(RuntimeException.class, () -> client.send(httpRequest, bodyHandler));
355-
assertEquals(NullPointerException.class, ex.getCause().getClass());
356-
verify(httpClient, times(4)).send(any(), eq(bodyHandler));
357-
assertEquals(retryInterceptor.getTimeout().toMillis(), timeout.toMillis());
358-
}
359331

360332
@Test
361333
void retry_with_exponential_backoff_succeed_after_retry() throws Exception {
@@ -368,16 +340,9 @@ void retry_with_exponential_backoff_succeed_after_retry() throws Exception {
368340
.build();
369341

370342
SyncHttpInterceptor mockInterceptor = new SyncHttpInterceptor() {
371-
int numCalled = 0;
372-
373343
@Override
374-
public <T> HttpResponse<T> intercept(HttpRequest request, BodyHandler<T> bodyHandler, int index,
375-
Chain chain)
344+
public <T> HttpResponse<T> intercept(HttpRequest request, BodyHandler<T> bodyHandler, int index, Chain chain)
376345
throws WatsonxException, IOException, InterruptedException {
377-
if (this.numCalled > 0)
378-
assertEquals((long) (timeout.toMillis() * Math.pow(2, this.numCalled - 1)),
379-
retryInterceptor.getTimeout().toMillis());
380-
this.numCalled++;
381346
return chain.proceed(request, bodyHandler);
382347
}
383348
};
@@ -389,15 +354,16 @@ public <T> HttpResponse<T> intercept(HttpRequest request, BodyHandler<T> bodyHan
389354
.build();
390355

391356
when(httpClient.send(any(), eq(bodyHandler)))
357+
.thenThrow(NullPointerException.class)
358+
.thenThrow(NullPointerException.class)
392359
.thenThrow(NullPointerException.class)
393360
.thenReturn(httpResponse);
394361

395362
when(httpResponse.statusCode())
396363
.thenReturn(200);
397364

398365
client.send(httpRequest, bodyHandler);
399-
verify(httpClient, times(2)).send(any(), eq(bodyHandler));
400-
assertEquals(retryInterceptor.getTimeout().toMillis(), timeout.toMillis());
366+
verify(httpClient, times(4)).send(any(), eq(bodyHandler));
401367
}
402368
}
403369

@@ -586,46 +552,6 @@ void retry_with_watsonx_exception() throws Exception {
586552
verify(mockInterceptor, times(2)).intercept(any(), eq(bodyHandler), any(), anyInt(), any());
587553
}
588554

589-
@Test
590-
@SuppressWarnings("unchecked")
591-
void retry_with_exponential_backoff_fail_retries() throws Exception {
592-
Duration timeout = Duration.ofMillis(10);
593-
RetryInterceptor retryInterceptor = RetryInterceptor.builder()
594-
.maxRetries(3)
595-
.retryInterval(timeout)
596-
.exponentialBackoff(true)
597-
.retryOn(NullPointerException.class)
598-
.build();
599-
600-
AsyncHttpInterceptor mockInterceptor = new AsyncHttpInterceptor() {
601-
int numCalled = 0;
602-
603-
@Override
604-
public <T> CompletableFuture<HttpResponse<T>> intercept(HttpRequest request, BodyHandler<T> bodyHandler,
605-
Executor executor, int index, AsyncChain chain) {
606-
if (this.numCalled > 0)
607-
assertEquals((long) (timeout.toMillis() * Math.pow(2, this.numCalled - 1)),
608-
retryInterceptor.getTimeout().toMillis());
609-
this.numCalled++;
610-
return chain.proceed(request, bodyHandler, executor);
611-
}
612-
};
613-
614-
AsyncHttpClient client = AsyncHttpClient.builder()
615-
.httpClient(httpClient)
616-
.interceptor(retryInterceptor)
617-
.interceptor(mockInterceptor)
618-
.build();
619-
620-
when(httpClient.sendAsync(any(), any(BodyHandler.class)))
621-
.thenReturn(CompletableFuture.failedFuture(new NullPointerException()));
622-
623-
var ex = assertThrows(RuntimeException.class, () -> client.send(httpRequest, bodyHandler).join());
624-
assertEquals(NullPointerException.class, ex.getCause().getClass());
625-
verify(httpClient, times(4)).sendAsync(any(), any(BodyHandler.class));
626-
assertEquals(retryInterceptor.getTimeout().toMillis(), timeout.toMillis());
627-
}
628-
629555
@Test
630556
@SuppressWarnings("unchecked")
631557
void retry_with_exponential_backoff_succeed_after_retry() throws Exception {
@@ -638,15 +564,9 @@ void retry_with_exponential_backoff_succeed_after_retry() throws Exception {
638564
.build();
639565

640566
AsyncHttpInterceptor mockInterceptor = new AsyncHttpInterceptor() {
641-
int numCalled = 0;
642-
643567
@Override
644568
public <T> CompletableFuture<HttpResponse<T>> intercept(HttpRequest request, BodyHandler<T> bodyHandler,
645569
Executor executor, int index, AsyncChain chain) {
646-
if (this.numCalled > 0)
647-
assertEquals((long) (timeout.toMillis() * Math.pow(2, this.numCalled - 1)),
648-
retryInterceptor.getTimeout().toMillis());
649-
this.numCalled++;
650570
return chain.proceed(request, bodyHandler, executor);
651571
}
652572
};
@@ -658,12 +578,13 @@ public <T> CompletableFuture<HttpResponse<T>> intercept(HttpRequest request, Bod
658578
.build();
659579

660580
when(httpClient.sendAsync(any(), any(BodyHandler.class)))
581+
.thenReturn(failedFuture(new NullPointerException()))
582+
.thenReturn(failedFuture(new NullPointerException()))
661583
.thenReturn(failedFuture(new NullPointerException()))
662584
.thenReturn(completedFuture(httpResponse));
663585

664586
client.send(httpRequest, bodyHandler).join();
665-
verify(httpClient, times(2)).sendAsync(any(), any(BodyHandler.class));
666-
assertEquals(retryInterceptor.getTimeout().toMillis(), timeout.toMillis());
587+
verify(httpClient, times(4)).sendAsync(any(), any(BodyHandler.class));
667588
}
668589
}
669590

0 commit comments

Comments
 (0)