Skip to content

Commit ed1f6ad

Browse files
committed
Add native-image support for RestTemplateBuilder
Closes gh-31888
1 parent a3d4431 commit ed1f6ad

File tree

3 files changed

+113
-19
lines changed

3 files changed

+113
-19
lines changed

spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySupplier.java

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616

1717
package org.springframework.boot.web.client;
1818

19-
import java.util.Collections;
20-
import java.util.LinkedHashMap;
21-
import java.util.Map;
19+
import java.util.function.Consumer;
2220
import java.util.function.Supplier;
2321

24-
import org.springframework.beans.BeanUtils;
22+
import org.springframework.aot.hint.RuntimeHints;
23+
import org.springframework.aot.hint.TypeHint.Builder;
24+
import org.springframework.aot.hint.TypeReference;
2525
import org.springframework.http.client.ClientHttpRequestFactory;
26+
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
27+
import org.springframework.http.client.OkHttp3ClientHttpRequestFactory;
2628
import org.springframework.http.client.SimpleClientHttpRequestFactory;
2729
import org.springframework.util.ClassUtils;
2830

@@ -31,30 +33,38 @@
3133
* based on the available implementations on the classpath.
3234
*
3335
* @author Stephane Nicoll
36+
* @author Moritz Halbritter
3437
* @since 2.1.0
3538
*/
3639
public class ClientHttpRequestFactorySupplier implements Supplier<ClientHttpRequestFactory> {
3740

38-
private static final Map<String, String> REQUEST_FACTORY_CANDIDATES;
41+
private static final boolean APACHE_HTTP_CLIENT_PRESENT = ClassUtils.isPresent("org.apache.http.client.HttpClient",
42+
null);
3943

40-
static {
41-
Map<String, String> candidates = new LinkedHashMap<>();
42-
candidates.put("org.apache.http.client.HttpClient",
43-
"org.springframework.http.client.HttpComponentsClientHttpRequestFactory");
44-
candidates.put("okhttp3.OkHttpClient", "org.springframework.http.client.OkHttp3ClientHttpRequestFactory");
45-
REQUEST_FACTORY_CANDIDATES = Collections.unmodifiableMap(candidates);
46-
}
44+
private static final boolean OKHTTP_CLIENT_PRESENT = ClassUtils.isPresent("okhttp3.OkHttpClient", null);
4745

4846
@Override
4947
public ClientHttpRequestFactory get() {
50-
for (Map.Entry<String, String> candidate : REQUEST_FACTORY_CANDIDATES.entrySet()) {
51-
ClassLoader classLoader = getClass().getClassLoader();
52-
if (ClassUtils.isPresent(candidate.getKey(), classLoader)) {
53-
Class<?> factoryClass = ClassUtils.resolveClassName(candidate.getValue(), classLoader);
54-
return (ClientHttpRequestFactory) BeanUtils.instantiateClass(factoryClass);
55-
}
48+
if (APACHE_HTTP_CLIENT_PRESENT) {
49+
return new HttpComponentsClientHttpRequestFactory();
50+
}
51+
if (OKHTTP_CLIENT_PRESENT) {
52+
return new OkHttp3ClientHttpRequestFactory();
5653
}
5754
return new SimpleClientHttpRequestFactory();
5855
}
5956

57+
static class ClientHttpRequestFactorySupplierRuntimeHints {
58+
59+
static void registerHints(RuntimeHints hints, ClassLoader classLoader, Consumer<Builder> callback) {
60+
hints.reflection().registerType(HttpComponentsClientHttpRequestFactory.class, (hint) -> callback
61+
.accept(hint.onReachableType(TypeReference.of("org.apache.http.client.HttpClient"))));
62+
hints.reflection().registerType(OkHttp3ClientHttpRequestFactory.class,
63+
(hint) -> callback.accept(hint.onReachableType(TypeReference.of("okhttp3.OkHttpClient"))));
64+
hints.reflection().registerType(SimpleClientHttpRequestFactory.class, (hint) -> callback
65+
.accept(hint.onReachableType(TypeReference.of(SimpleClientHttpRequestFactory.class))));
66+
}
67+
68+
}
69+
6070
}

spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,19 @@
2929
import java.util.LinkedHashSet;
3030
import java.util.List;
3131
import java.util.Map;
32+
import java.util.Objects;
3233
import java.util.Set;
3334
import java.util.function.Consumer;
3435
import java.util.function.Supplier;
3536

3637
import reactor.netty.http.client.HttpClientRequest;
3738

39+
import org.springframework.aot.hint.ExecutableMode;
40+
import org.springframework.aot.hint.RuntimeHints;
41+
import org.springframework.aot.hint.RuntimeHintsRegistrar;
42+
import org.springframework.aot.hint.TypeReference;
3843
import org.springframework.beans.BeanUtils;
44+
import org.springframework.context.annotation.ImportRuntimeHints;
3945
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
4046
import org.springframework.http.client.ClientHttpRequest;
4147
import org.springframework.http.client.ClientHttpRequestFactory;
@@ -56,7 +62,7 @@
5662
* converters}, {@link #errorHandler(ResponseErrorHandler) error handlers} and
5763
* {@link #uriTemplateHandler(UriTemplateHandler) UriTemplateHandlers}.
5864
* <p>
59-
* By default the built {@link RestTemplate} will attempt to use the most suitable
65+
* By default, the built {@link RestTemplate} will attempt to use the most suitable
6066
* {@link ClientHttpRequestFactory}, call {@link #detectRequestFactory(boolean)
6167
* detectRequestFactory(false)} if you prefer to keep the default. In a typical
6268
* auto-configured Spring Boot application this builder is available as a bean and can be
@@ -71,6 +77,7 @@
7177
* @author Ilya Lukyanovich
7278
* @since 1.4.0
7379
*/
80+
@ImportRuntimeHints(RestTemplateBuilder.RestTemplateBuilderRuntimeHints.class)
7481
public class RestTemplateBuilder {
7582

7683
private final RequestFactoryCustomizer requestFactoryCustomizer;
@@ -789,4 +796,23 @@ private void invoke(ClientHttpRequestFactory requestFactory, Method method, Obje
789796

790797
}
791798

799+
static class RestTemplateBuilderRuntimeHints implements RuntimeHintsRegistrar {
800+
801+
@Override
802+
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
803+
hints.reflection().registerField(Objects.requireNonNull(
804+
ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory")));
805+
ClientHttpRequestFactorySupplier.ClientHttpRequestFactorySupplierRuntimeHints.registerHints(hints,
806+
classLoader, (hint) -> {
807+
hint.withMethod("setConnectTimeout", List.of(TypeReference.of(int.class)),
808+
(method) -> method.withMode(ExecutableMode.INVOKE));
809+
hint.withMethod("setReadTimeout", List.of(TypeReference.of(int.class)),
810+
(method) -> method.withMode(ExecutableMode.INVOKE));
811+
hint.withMethod("setBufferRequestBody", List.of(TypeReference.of(boolean.class)),
812+
(method) -> method.withMode(ExecutableMode.INVOKE));
813+
});
814+
}
815+
816+
}
817+
792818
}

spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,14 @@
3333
import org.mockito.Mock;
3434
import org.mockito.junit.jupiter.MockitoExtension;
3535

36+
import org.springframework.aot.hint.RuntimeHints;
37+
import org.springframework.aot.hint.predicate.ReflectionHintsPredicates;
38+
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
39+
import org.springframework.boot.web.client.RestTemplateBuilder.RestTemplateBuilderRuntimeHints;
3640
import org.springframework.http.HttpHeaders;
3741
import org.springframework.http.HttpMethod;
3842
import org.springframework.http.MediaType;
43+
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
3944
import org.springframework.http.client.BufferingClientHttpRequestFactory;
4045
import org.springframework.http.client.ClientHttpRequest;
4146
import org.springframework.http.client.ClientHttpRequestFactory;
@@ -50,6 +55,7 @@
5055
import org.springframework.http.converter.StringHttpMessageConverter;
5156
import org.springframework.test.util.ReflectionTestUtils;
5257
import org.springframework.test.web.client.MockRestServiceServer;
58+
import org.springframework.util.ReflectionUtils;
5359
import org.springframework.web.client.ResponseErrorHandler;
5460
import org.springframework.web.client.RestTemplate;
5561
import org.springframework.web.util.UriTemplateHandler;
@@ -585,6 +591,58 @@ void unwrappingDoesNotAffectRequestFactoryThatIsSetOnTheBuiltTemplate() {
585591
assertThat(template.getRequestFactory()).isInstanceOf(BufferingClientHttpRequestFactory.class);
586592
}
587593

594+
@Test
595+
void shouldRegisterHints() {
596+
RuntimeHints hints = new RuntimeHints();
597+
new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader());
598+
ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection();
599+
assertThat(reflection
600+
.onField(ReflectionUtils.findField(AbstractClientHttpRequestFactoryWrapper.class, "requestFactory")))
601+
.accepts(hints);
602+
}
603+
604+
@Test
605+
void shouldRegisterHttpComponentHints() {
606+
RuntimeHints hints = new RuntimeHints();
607+
new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader());
608+
ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection();
609+
assertThat(reflection.onMethod(ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class,
610+
"setConnectTimeout", int.class))).accepts(hints);
611+
assertThat(reflection.onMethod(
612+
ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class, "setReadTimeout", int.class)))
613+
.accepts(hints);
614+
assertThat(reflection.onMethod(ReflectionUtils.findMethod(HttpComponentsClientHttpRequestFactory.class,
615+
"setBufferRequestBody", boolean.class))).accepts(hints);
616+
}
617+
618+
@Test
619+
void shouldRegisterOkHttpHints() {
620+
RuntimeHints hints = new RuntimeHints();
621+
new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader());
622+
ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection();
623+
assertThat(reflection.onMethod(
624+
ReflectionUtils.findMethod(OkHttp3ClientHttpRequestFactory.class, "setConnectTimeout", int.class)))
625+
.accepts(hints);
626+
assertThat(reflection.onMethod(
627+
ReflectionUtils.findMethod(OkHttp3ClientHttpRequestFactory.class, "setReadTimeout", int.class)))
628+
.accepts(hints);
629+
}
630+
631+
@Test
632+
void shouldRegisterSimpleHttpHints() {
633+
RuntimeHints hints = new RuntimeHints();
634+
new RestTemplateBuilderRuntimeHints().registerHints(hints, getClass().getClassLoader());
635+
ReflectionHintsPredicates reflection = RuntimeHintsPredicates.reflection();
636+
assertThat(reflection.onMethod(
637+
ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, "setConnectTimeout", int.class)))
638+
.accepts(hints);
639+
assertThat(reflection.onMethod(
640+
ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class, "setReadTimeout", int.class)))
641+
.accepts(hints);
642+
assertThat(reflection.onMethod(ReflectionUtils.findMethod(SimpleClientHttpRequestFactory.class,
643+
"setBufferRequestBody", boolean.class))).accepts(hints);
644+
}
645+
588646
private ClientHttpRequest createRequest(RestTemplate template) {
589647
return ReflectionTestUtils.invokeMethod(template, "createRequest", URI.create("http://localhost"),
590648
HttpMethod.GET);

0 commit comments

Comments
 (0)