diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java index e28f27e2f7c..d92ee16b4cb 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java @@ -49,6 +49,7 @@ import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.web.client.RestOperations; import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; @@ -283,6 +284,8 @@ public class JwtConfigurer { private Converter jwtAuthenticationConverter; + private RestOperations restOperations; + JwtConfigurer(ApplicationContext context) { this.context = context; } @@ -299,7 +302,15 @@ public JwtConfigurer decoder(JwtDecoder decoder) { } public JwtConfigurer jwkSetUri(String uri) { - this.decoder = withJwkSetUri(uri).build(); + final NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(uri); + this.decoder = restOperations == null + ? builder.build() + : builder.restOperations(restOperations).build(); + return this; + } + + public JwtConfigurer restOperations(RestOperations restOperations) { + this.restOperations = restOperations; return this; } diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParser.java index 5068a74b456..be1d2a07ec8 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParser.java @@ -50,6 +50,7 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; +import org.springframework.web.client.RestOperations; /** * A {@link BeanDefinitionParser} for <http>'s <oauth2-resource-server> element. @@ -194,6 +195,7 @@ BeanMetadataElement getEntryPoint(Element element) { final class JwtBeanDefinitionParser implements BeanDefinitionParser { static final String DECODER_REF = "decoder-ref"; static final String JWK_SET_URI = "jwk-set-uri"; + static final String REST_OPERATIONS_REF = "rest-operations-ref"; static final String JWT_AUTHENTICATION_CONVERTER_REF = "jwt-authentication-converter-ref"; static final String JWT_AUTHENTICATION_CONVERTER = "jwtAuthenticationConverter"; @@ -228,6 +230,12 @@ Object getDecoder(Element element) { BeanDefinitionBuilder builder = BeanDefinitionBuilder .rootBeanDefinition(NimbusJwtDecoderJwkSetUriFactoryBean.class); builder.addConstructorArgValue(element.getAttribute(JWK_SET_URI)); + final String restOperationsRef = element.getAttribute(REST_OPERATIONS_REF); + if (StringUtils.isEmpty(restOperationsRef)) { + builder.addConstructorArgValue(null); + } else { + builder.addConstructorArgReference(restOperationsRef); + } return builder.getBeanDefinition(); } @@ -322,14 +330,19 @@ public AuthenticationManager resolve(HttpServletRequest context) { final class NimbusJwtDecoderJwkSetUriFactoryBean implements FactoryBean { private final String jwkSetUri; + private final RestOperations restOperations; - NimbusJwtDecoderJwkSetUriFactoryBean(String jwkSetUri) { + NimbusJwtDecoderJwkSetUriFactoryBean(String jwkSetUri, RestOperations restOperations) { this.jwkSetUri = jwkSetUri; + this.restOperations = restOperations; } @Override public JwtDecoder getObject() { - return NimbusJwtDecoder.withJwkSetUri(this.jwkSetUri).build(); + final NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = NimbusJwtDecoder.withJwkSetUri(this.jwkSetUri); + return restOperations == null + ? builder.build() + : builder.restOperations(restOperations).build(); } @Override diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 91866bff42f..81792275ca5 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -58,6 +58,7 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; @@ -174,9 +175,12 @@ import org.springframework.web.cors.reactive.CorsProcessor; import org.springframework.web.cors.reactive.CorsWebFilter; import org.springframework.web.cors.reactive.DefaultCorsProcessor; +import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.match; @@ -1874,6 +1878,8 @@ public class JwtSpec { private ReactiveJwtDecoder jwtDecoder; private Converter> jwtAuthenticationConverter = new ReactiveJwtAuthenticationConverterAdapter(new JwtAuthenticationConverter()); + private WebClient webClient; + private String jwkSetUri; /** * Configures the {@link ReactiveAuthenticationManager} to use @@ -1929,10 +1935,27 @@ public JwtSpec publicKey(RSAPublicKey publicKey) { * @return the {@code JwtSpec} for additional configuration */ public JwtSpec jwkSetUri(String jwkSetUri) { - this.jwtDecoder = new NimbusReactiveJwtDecoder(jwkSetUri); + this.jwkSetUri = jwkSetUri; + this.jwtDecoder = createDecoder(); return this; } + public JwtSpec webClient(WebClient webClient) { + this.webClient = webClient; + this.jwtDecoder = createDecoder(); + return this; + } + + private ReactiveJwtDecoder createDecoder() { + if (jwkSetUri != null) { + return webClient == null + ? new NimbusReactiveJwtDecoder(jwkSetUri) + : new NimbusReactiveJwtDecoder(jwkSetUri, webClient); + } else { + return null; + } + } + public OAuth2ResourceServerSpec and() { return OAuth2ResourceServerSpec.this; } diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerJwtDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerJwtDsl.kt index 0ba0501c66a..64bcf4d1db7 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerJwtDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerJwtDsl.kt @@ -22,6 +22,7 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.core.Authentication import org.springframework.security.oauth2.jwt.Jwt import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder +import org.springframework.web.reactive.function.client.WebClient import reactor.core.publisher.Mono import java.security.interfaces.RSAPublicKey @@ -44,6 +45,7 @@ class ServerJwtDsl { private var _jwtDecoder: ReactiveJwtDecoder? = null private var _publicKey: RSAPublicKey? = null private var _jwkSetUri: String? = null + private var _webClient: WebClient? = null var authenticationManager: ReactiveAuthenticationManager? = null var jwtAuthenticationConverter: Converter>? = null @@ -69,14 +71,20 @@ class ServerJwtDsl { _jwtDecoder = null _publicKey = null } + var webClient: WebClient? + get() = _webClient + set(value) { + _webClient = value + } internal fun get(): (ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec) -> Unit { return { jwt -> authenticationManager?.also { jwt.authenticationManager(authenticationManager) } jwtAuthenticationConverter?.also { jwt.jwtAuthenticationConverter(jwtAuthenticationConverter) } - jwtDecoder?.also { jwt.jwtDecoder(jwtDecoder) } publicKey?.also { jwt.publicKey(publicKey) } + webClient?.also { jwt.webClient(webClient) } jwkSetUri?.also { jwt.jwkSetUri(jwkSetUri) } + jwtDecoder?.also { jwt.jwtDecoder(jwtDecoder) } } } } diff --git a/config/src/main/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDsl.kt index e8d8008a974..038e34cbeee 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDsl.kt @@ -22,6 +22,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.configurers.oauth2.server.resource.OAuth2ResourceServerConfigurer import org.springframework.security.oauth2.jwt.Jwt import org.springframework.security.oauth2.jwt.JwtDecoder +import org.springframework.web.client.RestOperations /** * A Kotlin DSL to configure JWT Resource Server Support using idiomatic Kotlin code. @@ -38,6 +39,7 @@ import org.springframework.security.oauth2.jwt.JwtDecoder class JwtDsl { private var _jwtDecoder: JwtDecoder? = null private var _jwkSetUri: String? = null + private var _restOperations: RestOperations? = null var jwtAuthenticationConverter: Converter? = null var jwtDecoder: JwtDecoder? @@ -53,10 +55,17 @@ class JwtDsl { _jwtDecoder = null } + var restOperations: RestOperations? + get() = _restOperations + set(value) { + _restOperations = value + } + internal fun get(): (OAuth2ResourceServerConfigurer.JwtConfigurer) -> Unit { return { jwt -> jwtAuthenticationConverter?.also { jwt.jwtAuthenticationConverter(jwtAuthenticationConverter) } jwtDecoder?.also { jwt.decoder(jwtDecoder) } + restOperations?.also { jwt.restOperations(restOperations) } jwkSetUri?.also { jwt.jwkSetUri(jwkSetUri) } } } diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.4.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-5.4.xsd index 436820de820..d391a9d33e9 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.4.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.4.xsd @@ -124,7 +124,7 @@ - + @@ -408,7 +408,7 @@ - + @@ -488,7 +488,7 @@ - + @@ -541,7 +541,7 @@ - + @@ -785,13 +785,13 @@ - - - - - - - + + + + + + + @@ -1240,7 +1240,7 @@ - + @@ -1265,7 +1265,7 @@ - + @@ -1322,7 +1322,7 @@ - + @@ -1369,7 +1369,7 @@ - + @@ -1834,6 +1834,12 @@ + + + Reference to a RestOperations bean + + + Reference to a JwtDecoder @@ -1882,7 +1888,7 @@ - + Sets up an attribute exchange configuration to request specified attributes from the @@ -2085,7 +2091,7 @@ - + @@ -2101,7 +2107,7 @@ - + @@ -2157,7 +2163,7 @@ - + @@ -2204,7 +2210,7 @@ - + @@ -2302,7 +2308,7 @@ - + @@ -2335,8 +2341,8 @@ - - + + @@ -2353,7 +2359,7 @@ - + @@ -2490,7 +2496,7 @@ - + @@ -2542,7 +2548,7 @@ - + @@ -3185,4 +3191,4 @@ - \ No newline at end of file + diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java index 15186e0b7ef..950e5ed70bb 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java @@ -47,7 +47,6 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mockito; import org.w3c.dom.Element; import org.springframework.beans.factory.DisposableBean; @@ -97,11 +96,12 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.when; import static org.springframework.security.config.http.JwtBeanDefinitionParser.DECODER_REF; import static org.springframework.security.config.http.JwtBeanDefinitionParser.JWK_SET_URI; import static org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.AUTHENTICATION_MANAGER_RESOLVER_REF; @@ -158,6 +158,23 @@ public void getWhenUsingJwkSetUriThenAcceptsRequest() throws Exception { .andExpect(status().isNotFound()); } + @Test + public void jwkSetUriWithRestOperations() throws Exception { + spring.configLocations(xml("WebServer"), xml("JwkSetUriRestOperations")).autowire(); + + RestOperations restOperations = spring.getContext().getBean(RestOperations.class); + when(restOperations.exchange(any(), eq(String.class))).thenThrow(new IllegalStateException("custom rest-operations")); + + assertThatThrownBy(() -> { + mvc.perform(get("/") + .header("Authorization", "Bearer " + token("ValidNoScopes"))); + }).hasRootCauseInstanceOf(IllegalStateException.class) + .hasRootCauseMessage("custom rest-operations"); + + verify(restOperations).exchange(any(), eq(String.class)); + verifyNoMoreInteractions(restOperations); + } + @Test public void getWhenExpiredBearerTokenThenInvalidToken() throws Exception { @@ -531,7 +548,7 @@ public void requestWhenBearerTokenResolverAllowsQueryParameterThenEitherHeaderOr this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInQuery")).autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - Mockito.when(decoder.decode(anyString())).thenReturn(jwt().build()); + when(decoder.decode(anyString())).thenReturn(jwt().build()); this.mvc.perform(get("/authenticated") .header("Authorization", "Bearer token")) @@ -616,7 +633,7 @@ public void requestWhenRealmNameConfiguredThenUsesOnUnauthenticated() this.spring.configLocations(xml("MockJwtDecoder"), xml("AuthenticationEntryPoint")).autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - Mockito.when(decoder.decode(anyString())).thenThrow(JwtException.class); + when(decoder.decode(anyString())).thenThrow(JwtException.class); this.mvc.perform(get("/authenticated") .header("Authorization", "Bearer invalid_token")) @@ -631,7 +648,7 @@ public void requestWhenRealmNameConfiguredThenUsesOnAccessDenied() this.spring.configLocations(xml("MockJwtDecoder"), xml("AccessDeniedHandler")).autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - Mockito.when(decoder.decode(anyString())).thenReturn(jwt().build()); + when(decoder.decode(anyString())).thenReturn(jwt().build()); this.mvc.perform(get("/authenticated") .header("Authorization", "Bearer insufficiently_scoped")) @@ -703,7 +720,7 @@ public void requestWhenJwtAuthenticationConverterThenUsed() .thenReturn(new JwtAuthenticationToken(jwt().build(), Collections.emptyList())); JwtDecoder jwtDecoder = this.spring.getContext().getBean(JwtDecoder.class); - Mockito.when(jwtDecoder.decode(anyString())).thenReturn(jwt().build()); + when(jwtDecoder.decode(anyString())).thenReturn(jwt().build()); this.mvc.perform(get("/") .header("Authorization", "Bearer token")) @@ -1205,7 +1222,7 @@ private void mockRestOperations(String response) { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); ResponseEntity entity = new ResponseEntity<>(response, headers, HttpStatus.OK); - Mockito.when(rest.exchange(any(RequestEntity.class), eq(String.class))) + when(rest.exchange(any(RequestEntity.class), eq(String.class))) .thenReturn(entity); } diff --git a/config/src/test/java/org/springframework/security/config/util/AlwaysRethrowAuthenticationEntryPoint.java b/config/src/test/java/org/springframework/security/config/util/AlwaysRethrowAuthenticationEntryPoint.java new file mode 100644 index 00000000000..bfe57c6da05 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/util/AlwaysRethrowAuthenticationEntryPoint.java @@ -0,0 +1,29 @@ +/* + * Copyright 2009-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.util; + +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.web.AuthenticationEntryPoint; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +public class AlwaysRethrowAuthenticationEntryPoint implements AuthenticationEntryPoint { + @Override + public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) { + throw authException; + } +} diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerJwtDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerJwtDslTests.kt index ddb33b5323a..724bdbeb660 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerJwtDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerJwtDslTests.kt @@ -18,15 +18,20 @@ package org.springframework.security.config.web.server import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockWebServer +import org.apache.commons.lang3.exception.ExceptionUtils import org.assertj.core.api.Assertions.assertThat import org.junit.Rule import org.junit.Test +import org.mockito.ArgumentMatchers import org.mockito.Mockito.* import org.springframework.beans.factory.annotation.Autowired import org.springframework.context.ApplicationContext import org.springframework.context.annotation.Bean +import org.springframework.core.annotation.Order import org.springframework.core.convert.converter.Converter import org.springframework.http.HttpHeaders +import org.springframework.http.HttpStatus +import org.springframework.http.MediaType import org.springframework.security.authentication.AbstractAuthenticationToken import org.springframework.security.authentication.TestingAuthenticationToken import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity @@ -39,6 +44,9 @@ import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.web.bind.annotation.GetMapping import org.springframework.web.bind.annotation.RestController import org.springframework.web.reactive.config.EnableWebFlux +import org.springframework.web.reactive.function.client.WebClient +import org.springframework.web.server.ServerWebExchange +import org.springframework.web.server.WebExceptionHandler import reactor.core.publisher.Mono import java.math.BigInteger import java.security.KeyFactory @@ -203,6 +211,116 @@ class ServerJwtDslTests { } } + @Test + fun `jwt when using custom JWK Set URI than webClient set`() { + this.spring.register(CustomJwkSetUriThanWebClientConfig::class.java).autowire() + + val result = this.client.get() + .uri("/") + .headers { headers: HttpHeaders -> headers.setBearerAuth(messageReadToken) } + .exchange() + .returnResult(String::class.java) + .responseBodyContent?.toString(Charsets.UTF_8) ?: "invalid" + assertThat(result) + .contains("method get(...) invoked") + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomJwkSetUriThanWebClientConfig { + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + authorizeExchange { + authorize(anyExchange, authenticated) + } + oauth2ResourceServer { + jwt { + jwkSetUri = "http://foo/.well-known/jwks.json" + webClient = alwaysFailWebClient() + } + } + } + } + + @Bean + open fun globalErrorHandler() = GlobalErrorHandler() + } + + @Test + fun `jwt when using custom webClient than JWK Set URI set`() { + this.spring.register(CustomWebClientThanJwkSetUriConfig::class.java).autowire() + + val result = this.client.get() + .uri("/") + .headers { headers: HttpHeaders -> headers.setBearerAuth(messageReadToken) } + .exchange() + .returnResult(String::class.java) + .responseBodyContent?.toString(Charsets.UTF_8) ?: "invalid" + assertThat(result) + .contains("method get(...) invoked") + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomWebClientThanJwkSetUriConfig { + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + authorizeExchange { + authorize(anyExchange, authenticated) + } + oauth2ResourceServer { + jwt { + webClient = alwaysFailWebClient() + jwkSetUri = "http://foo/.well-known/jwks.json" + } + } + } + } + + @Bean + open fun globalErrorHandler() = GlobalErrorHandler() + } + + @Test + fun `jwt webClient and JWK Set URI replaced with jwtDecoder`() { + this.spring.register(CustomWebClientAndJwkSetUriReplacedWithJwtDecoderConfig::class.java).autowire() + + val result = this.client.get() + .uri("/") + .headers { headers: HttpHeaders -> headers.setBearerAuth(messageReadToken) } + .exchange() + .returnResult(String::class.java) + .responseBodyContent?.toString(Charsets.UTF_8) ?: "invalid" + assertThat(result) + .contains("replaced jwtDecoder") + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomWebClientAndJwkSetUriReplacedWithJwtDecoderConfig { + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + authorizeExchange { + authorize(anyExchange, authenticated) + } + oauth2ResourceServer { + jwt { + webClient = alwaysFailWebClient() + jwkSetUri = "http://foo/.well-known/jwks.json" + jwtDecoder = mock(ReactiveJwtDecoder::class.java).also { + `when`(it.decode(ArgumentMatchers.anyString())).thenReturn(Mono.error(java.lang.IllegalStateException("replaced jwtDecoder"))) + } + } + } + } + } + + @Bean + open fun globalErrorHandler() = GlobalErrorHandler() + } @Test fun `opaque token when custom JWT authentication converter then converter used`() { @@ -266,5 +384,30 @@ class ServerJwtDslTests { val factory = KeyFactory.getInstance("RSA") return factory.generatePublic(spec) as RSAPublicKey } + + private fun alwaysFailWebClient(): WebClient { + return mock(WebClient::class.java) { + when (it.method.name) { + "toString", "equals", "hashCode" -> { + return@mock it.callRealMethod() + } + else -> { + error("method ${it.method.name}(...) invoked") + } + } + }!! + } + + @Order(-2) + open class GlobalErrorHandler : WebExceptionHandler { + override fun handle(serverWebExchange: ServerWebExchange, throwable: Throwable): Mono { + val response = serverWebExchange.response + response.statusCode = HttpStatus.INTERNAL_SERVER_ERROR + response.headers.contentType = MediaType.TEXT_PLAIN + val dataBuffer = response.bufferFactory() + .wrap(ExceptionUtils.getStackTrace(throwable).toByteArray()) + return response.writeWith(Mono.just(dataBuffer)) + } + } } } diff --git a/config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDslTests.kt index bb088954d5c..8b671fc0150 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDslTests.kt @@ -16,8 +16,17 @@ package org.springframework.security.config.web.servlet.oauth2.resourceserver +import com.nimbusds.jose.JWSAlgorithm +import com.nimbusds.jwt.SignedJWT +import com.nimbusds.oauth2.sdk.assertions.jwt.JWTAssertionDetails +import com.nimbusds.oauth2.sdk.assertions.jwt.JWTAssertionFactory +import com.nimbusds.oauth2.sdk.id.Audience +import com.nimbusds.oauth2.sdk.id.Issuer +import com.nimbusds.oauth2.sdk.id.Subject +import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Rule import org.junit.Test +import org.mockito.ArgumentMatchers.anyString import org.mockito.Mockito.* import org.springframework.beans.factory.annotation.Autowired import org.springframework.context.annotation.Bean @@ -28,12 +37,15 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter import org.springframework.security.config.test.SpringTestRule +import org.springframework.security.config.util.AlwaysRethrowAuthenticationEntryPoint import org.springframework.security.config.web.servlet.invoke import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames +import org.springframework.security.oauth2.jose.TestKeys import org.springframework.security.oauth2.jwt.Jwt import org.springframework.security.oauth2.jwt.JwtDecoder import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.get +import org.springframework.web.client.RestOperations /** * Tests for [JwtDsl] @@ -84,6 +96,92 @@ class JwtDslTests { } } + @Test + fun `JWT when custom jwkSetUri than restOperations set`() { + spring.register(CustomJwkSetUriThanRestOperationsConfig::class.java).autowire() + + assertThatThrownBy { + mockMvc.get("/") { + header("Authorization", "Bearer ${createSignetJwt().serialize()}") + } + }.hasRootCauseInstanceOf(IllegalStateException::class.java) + .hasRootCauseMessage("method exchange(...) invoked") + } + + @EnableWebSecurity + open class CustomJwkSetUriThanRestOperationsConfig : WebSecurityConfigurerAdapter() { + override fun configure(http: HttpSecurity) { + http { + oauth2ResourceServer { + jwt { + jwtDecoder = mock(JwtDecoder::class.java) + jwkSetUri = "https://my-jwk-uri" + restOperations = alwaysFailRestOperations() + } + authenticationEntryPoint = AlwaysRethrowAuthenticationEntryPoint() + } + } + } + } + + @Test + fun `JWT when custom restOperations than jwkSetUri set`() { + spring.register(CustomRestOperationsThanJwkSetUriConfig::class.java).autowire() + + assertThatThrownBy { + mockMvc.get("/") { + header("Authorization", "Bearer ${createSignetJwt().serialize()}") + } + }.hasRootCauseInstanceOf(IllegalStateException::class.java) + .hasRootCauseMessage("method exchange(...) invoked") + } + + @EnableWebSecurity + open class CustomRestOperationsThanJwkSetUriConfig : WebSecurityConfigurerAdapter() { + override fun configure(http: HttpSecurity) { + http { + oauth2ResourceServer { + jwt { + jwtDecoder = mock(JwtDecoder::class.java) + restOperations = alwaysFailRestOperations() + jwkSetUri = "https://my-jwk-uri" + } + authenticationEntryPoint = AlwaysRethrowAuthenticationEntryPoint() + } + } + } + } + + @Test + fun `JWT when custom restOperations and jwkSetUri replaced with jwtDecoder`() { + spring.register(CustomRestOperationsAndJwkSetUriReplacedWithJwtDecoderConfig::class.java).autowire() + + assertThatThrownBy { + mockMvc.get("/") { + header("Authorization", "Bearer ${createSignetJwt().serialize()}") + } + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("replaced jwtDecoder") + } + + @EnableWebSecurity + open class CustomRestOperationsAndJwkSetUriReplacedWithJwtDecoderConfig : WebSecurityConfigurerAdapter() { + override fun configure(http: HttpSecurity) { + http { + oauth2ResourceServer { + jwt { + restOperations = alwaysFailRestOperations() + jwkSetUri = "https://my-jwk-uri" + jwtDecoder = mock(JwtDecoder::class.java).also { + `when`(it.decode(anyString())).thenThrow(java.lang.IllegalStateException("replaced jwtDecoder")) + } + } + authenticationEntryPoint = AlwaysRethrowAuthenticationEntryPoint() + } + } + } + } + @Test fun `JWT when custom JWT authentication converter then converter used`() { this.spring.register(CustomJwtAuthenticationConverterConfig::class.java).autowire() @@ -163,4 +261,25 @@ class JwtDslTests { } } } + + companion object { + private fun alwaysFailRestOperations(): RestOperations { + return mock(RestOperations::class.java) { + when (it.method.name) { + "toString", "equals", "hashCode" -> { + return@mock it.callRealMethod() + } + else -> { + error("method ${it.method.name}(...) invoked") + } + } + }!! + } + + private fun createSignetJwt(): SignedJWT { + return JWTAssertionFactory.create( + JWTAssertionDetails(Issuer("issuer"), Subject("sub"), Audience("aud")), + JWSAlgorithm.RS256, TestKeys.privateKey(), "key-id", null) + } + } } diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests-JwkSetUriRestOperations.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests-JwkSetUriRestOperations.xml new file mode 100644 index 00000000000..ee83a504a25 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests-JwkSetUriRestOperations.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/manual/src/docs/asciidoc/_includes/servlet/appendix/namespace.adoc b/docs/manual/src/docs/asciidoc/_includes/servlet/appendix/namespace.adoc index 02754f1780e..07761dbd967 100644 --- a/docs/manual/src/docs/asciidoc/_includes/servlet/appendix/namespace.adoc +++ b/docs/manual/src/docs/asciidoc/_includes/servlet/appendix/namespace.adoc @@ -1187,6 +1187,10 @@ Reference to a `JwtDecoder`. This is a larger component that overrides `jwk-set- * **jwk-set-uri** The JWK Set Uri used to load signing verification keys from an OAuth 2.0 Authorization Server +[[nsa-jwt-rest-operations-ref]] +* **rest-operations-ref** +The reference to bean of RestOperations used to load JWK Set from Uri from an OAuth 2.0 Authorization Server + [[nsa-opaque-token]] ==== Represents an OAuth 2.0 Resource Server that will authorize opaque tokens diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java index da51f71d629..5be5b5fa12e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java @@ -46,6 +46,7 @@ import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.client.RestOperations; import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecretKey; @@ -79,6 +80,7 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory jwsAlgorithmResolver = clientRegistration -> SignatureAlgorithm.RS256; private Function, Map>> claimTypeConverterFactory = clientRegistration -> DEFAULT_CLAIM_TYPE_CONVERTER; + private RestOperations restOperations; /** * Returns the default {@link Converter}'s used for type conversion of claim values for an {@link OidcIdToken}. @@ -153,7 +155,12 @@ private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) { ); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - return withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build(); + final NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(jwkSetUri) + .jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm); + if (restOperations != null) { + builder.restOperations(restOperations); + } + return builder.build(); } else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // @@ -226,4 +233,9 @@ public void setClaimTypeConverterFactory(Function jwsAlgorithmResolver = clientRegistration -> SignatureAlgorithm.RS256; private Function, Map>> claimTypeConverterFactory = clientRegistration -> DEFAULT_CLAIM_TYPE_CONVERTER; + private WebClient webClient; /** * Returns the default {@link Converter}'s used for type conversion of claim values for an {@link OidcIdToken}. @@ -153,7 +155,12 @@ private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistrat ); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - return withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build(); + final NimbusReactiveJwtDecoder.JwkSetUriReactiveJwtDecoderBuilder builder = withJwkSetUri(jwkSetUri) + .jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm); + if (webClient != null) { + builder.webClient(webClient); + } + return builder.build(); } else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // @@ -226,4 +233,8 @@ public void setClaimTypeConverterFactory(Function> typeReference = new ParameterizedTypeReference>() {}; - static Map getConfigurationForOidcIssuerLocation(String oidcIssuerLocation) { - return getConfiguration(oidcIssuerLocation, oidc(URI.create(oidcIssuerLocation))); + static Map getConfigurationForOidcIssuerLocation(String oidcIssuerLocation, + RestOperations restOperations) { + return getConfiguration(oidcIssuerLocation, + restOperations == null ? DEFAULT_REST : restOperations, + oidc(URI.create(oidcIssuerLocation))); } - static Map getConfigurationForIssuerLocation(String issuer) { + static Map getConfigurationForIssuerLocation(String issuer, + RestOperations restOperations) { URI uri = URI.create(issuer); - return getConfiguration(issuer, oidc(uri), oidcRfc8414(uri), oauth(uri)); + return getConfiguration(issuer, + restOperations == null ? DEFAULT_REST : restOperations, + oidc(uri), + oidcRfc8414(uri), + oauth(uri)); } static void validateIssuer(Map configuration, String issuer) { @@ -63,13 +72,13 @@ static void validateIssuer(Map configuration, String issuer) { } } - private static Map getConfiguration(String issuer, URI... uris) { + private static Map getConfiguration(String issuer, RestOperations restOperations, URI... uris) { String errorMessage = "Unable to resolve the Configuration with the provided Issuer of " + "\"" + issuer + "\""; for (URI uri : uris) { try { RequestEntity request = RequestEntity.get(uri).build(); - ResponseEntity> response = rest.exchange(request, typeReference); + ResponseEntity> response = restOperations.exchange(request, typeReference); Map configuration = response.getBody(); if (configuration.get("jwks_uri") == null) { diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java index 0d2f2331987..890b33b4b97 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java @@ -19,6 +19,7 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.util.Assert; +import org.springframework.web.client.RestOperations; import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; @@ -46,9 +47,25 @@ public final class JwtDecoders { * @return a {@link JwtDecoder} that was initialized by the OpenID Provider Configuration. */ public static JwtDecoder fromOidcIssuerLocation(String oidcIssuerLocation) { + return fromOidcIssuerLocation(oidcIssuerLocation, null); + } + + /** + * Creates a {@link JwtDecoder} using the provided + * Issuer by making an + * OpenID Provider + * Configuration Request and using the values in the + * OpenID + * Provider Configuration Response to initialize the {@link JwtDecoder}. + * + * @param oidcIssuerLocation the Issuer + * @param restOperations used as http-client (if not specified default http-client is used) + * @return a {@link JwtDecoder} that was initialized by the OpenID Provider Configuration. + */ + public static JwtDecoder fromOidcIssuerLocation(String oidcIssuerLocation, RestOperations restOperations) { Assert.hasText(oidcIssuerLocation, "oidcIssuerLocation cannot be empty"); - Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForOidcIssuerLocation(oidcIssuerLocation); - return withProviderConfiguration(configuration, oidcIssuerLocation); + Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForOidcIssuerLocation(oidcIssuerLocation, restOperations); + return withProviderConfiguration(configuration, oidcIssuerLocation, restOperations); } /** @@ -84,9 +101,46 @@ public static JwtDecoder fromOidcIssuerLocation(String oidcIssuerLocation) { * @return a {@link JwtDecoder} that was initialized by one of the described endpoints */ public static JwtDecoder fromIssuerLocation(String issuer) { + return fromIssuerLocation(issuer, null); + } + + /** + * Creates a {@link JwtDecoder} using the provided + * Issuer by querying + * three different discovery endpoints serially, using the values in the first successful response to + * initialize. If an endpoint returns anything other than a 200 or a 4xx, the method will exit without + * attempting subsequent endpoints. + * + * The three endpoints are computed as follows, given that the {@code issuer} is composed of a {@code host} + * and a {@code path}: + * + *
    + *
  1. + * {@code host/.well-known/openid-configuration/path}, as defined in + * RFC 8414's Compatibility Notes. + *
  2. + *
  3. + * {@code issuer/.well-known/openid-configuration}, as defined in + * + * OpenID Provider Configuration. + *
  4. + *
  5. + * {@code host/.well-known/oauth-authorization-server/path}, as defined in + * Authorization Server Metadata Request. + *
  6. + *
+ * + * Note that the second endpoint is the equivalent of calling + * {@link JwtDecoders#fromOidcIssuerLocation(String)} + * + * @param issuer the Issuer + * @param restOperations used as http-client (if not specified default http-client is used) + * @return a {@link JwtDecoder} that was initialized by one of the described endpoints + */ + public static JwtDecoder fromIssuerLocation(String issuer, RestOperations restOperations) { Assert.hasText(issuer, "issuer cannot be empty"); - Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForIssuerLocation(issuer); - return withProviderConfiguration(configuration, issuer); + Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForIssuerLocation(issuer, restOperations); + return withProviderConfiguration(configuration, issuer, restOperations); } /** @@ -97,12 +151,20 @@ public static JwtDecoder fromIssuerLocation(String issuer) { * * @param configuration the configuration values * @param issuer the Issuer + * @param restOperations used as http-client (if not specified default http-client is used) * @return {@link JwtDecoder} */ - private static JwtDecoder withProviderConfiguration(Map configuration, String issuer) { + private static JwtDecoder withProviderConfiguration( + Map configuration, + String issuer, + RestOperations restOperations) { JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer); OAuth2TokenValidator jwtValidator = JwtValidators.createDefaultWithIssuer(issuer); - NimbusJwtDecoder jwtDecoder = withJwkSetUri(configuration.get("jwks_uri").toString()).build(); + final NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(configuration.get("jwks_uri").toString()); + if (restOperations != null) { + builder.restOperations(restOperations); + } + NimbusJwtDecoder jwtDecoder = builder.build(); jwtDecoder.setJwtValidator(jwtValidator); return jwtDecoder; diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 44daabd13c7..cb6597bc5c1 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -215,9 +215,10 @@ public static SecretKeyJwtDecoderBuilder withSecretKey(SecretKey secretKey) { * JWK Set uri. */ public static final class JwkSetUriJwtDecoderBuilder { - private String jwkSetUri; - private Set signatureAlgorithms = new HashSet<>(); - private RestOperations restOperations = new RestTemplate(); + private static final RestTemplate DEFAULT_REST = new RestTemplate(); + private final String jwkSetUri; + private final Set signatureAlgorithms = new HashSet<>(); + private RestOperations restOperations = DEFAULT_REST; private Cache cache; private JwkSetUriJwtDecoderBuilder(String jwkSetUri) { diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index fa82a3899c9..03ef8c17e69 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -87,6 +87,9 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { * * @param jwkSetUrl the JSON Web Key (JWK) Set {@code URL} */ + public NimbusReactiveJwtDecoder(String jwkSetUrl, WebClient webClient) { + this(withJwkSetUri(jwkSetUrl).webClient(webClient).processor()); + } public NimbusReactiveJwtDecoder(String jwkSetUrl) { this(withJwkSetUri(jwkSetUrl).processor()); } @@ -242,9 +245,10 @@ public static JwkSourceReactiveJwtDecoderBuilder withJwkSource(Function signatureAlgorithms = new HashSet<>(); - private WebClient webClient = WebClient.create(); + private WebClient webClient = DEFAULT_WEBCLIENT; private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java index d062b515cfb..af4cde788ad 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java @@ -19,6 +19,8 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.util.Assert; +import org.springframework.web.client.RestOperations; +import org.springframework.web.reactive.function.client.WebClient; import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri; @@ -45,9 +47,28 @@ public final class ReactiveJwtDecoders { * @return a {@link ReactiveJwtDecoder} that was initialized by the OpenID Provider Configuration. */ public static ReactiveJwtDecoder fromOidcIssuerLocation(String oidcIssuerLocation) { + return fromOidcIssuerLocation(oidcIssuerLocation, null, null); + } + + /** + * Creates a {@link ReactiveJwtDecoder} using the provided + * Issuer by making an + * OpenID Provider + * Configuration Request and using the values in the + * OpenID + * Provider Configuration Response to initialize the {@link ReactiveJwtDecoder}. + * + * @param oidcIssuerLocation the Issuer + * @param webClient used as http-client (if not specified default http-client is used) for reactive operations + * @param restOperations used as http-client (if not specified default http-client is used) + * @return a {@link ReactiveJwtDecoder} that was initialized by the OpenID Provider Configuration. + */ + public static ReactiveJwtDecoder fromOidcIssuerLocation(String oidcIssuerLocation, + RestOperations restOperations, + WebClient webClient) { Assert.hasText(oidcIssuerLocation, "oidcIssuerLocation cannot be empty"); - Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForOidcIssuerLocation(oidcIssuerLocation); - return withProviderConfiguration(configuration, oidcIssuerLocation); + Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForOidcIssuerLocation(oidcIssuerLocation, restOperations); + return withProviderConfiguration(configuration, oidcIssuerLocation, webClient); } /** @@ -83,9 +104,49 @@ public static ReactiveJwtDecoder fromOidcIssuerLocation(String oidcIssuerLocatio * @return a {@link ReactiveJwtDecoder} that was initialized by one of the described endpoints */ public static ReactiveJwtDecoder fromIssuerLocation(String issuer) { + return fromIssuerLocation(issuer, null, null); + } + + /** + * Creates a {@link ReactiveJwtDecoder} using the provided + * Issuer by querying + * three different discovery endpoints serially, using the values in the first successful response to + * initialize. If an endpoint returns anything other than a 200 or a 4xx, the method will exit without + * attempting subsequent endpoints. + * + * The three endpoints are computed as follows, given that the {@code issuer} is composed of a {@code host} + * and a {@code path}: + * + *
    + *
  1. + * {@code host/.well-known/openid-configuration/path}, as defined in + * RFC 8414's Compatibility Notes. + *
  2. + *
  3. + * {@code issuer/.well-known/openid-configuration}, as defined in + * + * OpenID Provider Configuration. + *
  4. + *
  5. + * {@code host/.well-known/oauth-authorization-server/path}, as defined in + * Authorization Server Metadata Request. + *
  6. + *
+ * + * Note that the second endpoint is the equivalent of calling + * {@link ReactiveJwtDecoders#fromOidcIssuerLocation(String)} + * + * @param issuer the Issuer + * @param webClient used as http-client (if not specified default http-client is used) for reactive operations + * @param restOperations used as http-client (if not specified default http-client is used) + * @return a {@link ReactiveJwtDecoder} that was initialized by one of the described endpoints + */ + public static ReactiveJwtDecoder fromIssuerLocation(String issuer, + RestOperations restOperations, + WebClient webClient) { Assert.hasText(issuer, "issuer cannot be empty"); - Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForIssuerLocation(issuer); - return withProviderConfiguration(configuration, issuer); + Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForIssuerLocation(issuer, restOperations); + return withProviderConfiguration(configuration, issuer, webClient); } /** @@ -95,13 +156,21 @@ public static ReactiveJwtDecoder fromIssuerLocation(String issuer) { * Response. * * @param configuration the configuration values - * @param issuer the Issuer + * @param issuer the Issuer + * @param webClient used as http-client (if not specified default http-client is used) for reactive operations * @return {@link ReactiveJwtDecoder} */ - private static ReactiveJwtDecoder withProviderConfiguration(Map configuration, String issuer) { + private static ReactiveJwtDecoder withProviderConfiguration(Map configuration, + String issuer, + WebClient webClient) { JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer); OAuth2TokenValidator jwtValidator = JwtValidators.createDefaultWithIssuer(issuer); - NimbusReactiveJwtDecoder jwtDecoder = withJwkSetUri(configuration.get("jwks_uri").toString()).build(); + final NimbusReactiveJwtDecoder.JwkSetUriReactiveJwtDecoderBuilder builder = + withJwkSetUri(configuration.get("jwks_uri").toString()); + if (webClient != null) { + builder.webClient(webClient); + } + NimbusReactiveJwtDecoder jwtDecoder = builder.build(); jwtDecoder.setJwtValidator(jwtValidator); return jwtDecoder; diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java index 555d92b347b..fd0cd1f3e4a 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java @@ -96,7 +96,7 @@ private static RSAPublicKey publicKey() { public static final RSAPrivateKey DEFAULT_PRIVATE_KEY = privateKey(); - private static RSAPrivateKey privateKey() { + public static RSAPrivateKey privateKey() { PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(Base64.getDecoder().decode(DEFAULT_RSA_PRIVATE_KEY)); try { return (RSAPrivateKey) kf.generatePrivate(spec); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java index 8fa7ee58690..7269068c4b9 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java @@ -16,14 +16,24 @@ package org.springframework.security.oauth2.jwt; import java.net.URI; +import java.security.interfaces.RSAPrivateKey; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.oauth2.sdk.assertions.jwt.JWTAssertionDetails; +import com.nimbusds.oauth2.sdk.assertions.jwt.JWTAssertionFactory; +import com.nimbusds.oauth2.sdk.id.Audience; +import com.nimbusds.oauth2.sdk.id.Issuer; +import com.nimbusds.oauth2.sdk.id.Subject; import okhttp3.HttpUrl; import okhttp3.mockwebserver.Dispatcher; import okhttp3.mockwebserver.MockResponse; @@ -33,12 +43,22 @@ import org.junit.Before; import org.junit.Test; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.web.client.RestOperations; import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; /** * Tests for {@link JwtDecoders} @@ -130,6 +150,47 @@ public void issuerWhenContainsTrailingSlashThenSuccess() { assertThat(this.issuer).endsWith("/"); } + @Test + public void fromOidcIssuerLocationWithRestOperations() throws JsonProcessingException, JOSEException { + withRestOperationsTestFromIssuerLocation(restOperations -> + JwtDecoders.fromOidcIssuerLocation(this.issuer, restOperations)); + } + + @Test + public void fromIssuerLocationWithRestOperations() throws JsonProcessingException, JOSEException { + withRestOperationsTestFromIssuerLocation(restOperations -> + JwtDecoders.fromIssuerLocation(this.issuer, restOperations)); + } + + private void withRestOperationsTestFromIssuerLocation(Function invoker) throws JsonProcessingException, JOSEException { + final ParameterizedTypeReference> springTypeReference = + new ParameterizedTypeReference>() { + }; + final TypeReference> jacksonTypeReference = new TypeReference>() { + }; + + final String responseStr = String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer, this.issuer); + final ResponseEntity> response = ResponseEntity.ok(new ObjectMapper() + .readValue(responseStr, jacksonTypeReference)); + final RestOperations restOperations = mock(RestOperations.class); + when(restOperations.exchange(any(), eq(springTypeReference))).thenReturn(response); + when(restOperations.exchange(any(), eq(String.class))).thenThrow(new IllegalStateException("/.well-known/jwks.json call")); + + final JwtDecoder jwtDecoder = invoker.apply(restOperations); + + final RSAPrivateKey privateKey = TestKeys.privateKey(); + final SignedJWT signedJWT = JWTAssertionFactory.create( + new JWTAssertionDetails(new Issuer("issuer"), new Subject("sub"), new Audience("aud")), + JWSAlgorithm.RS256, privateKey, "key-id", null); + assertThatCode(() -> jwtDecoder.decode(signedJWT.serialize())) + .hasRootCauseInstanceOf(IllegalStateException.class) + .hasRootCauseMessage("/.well-known/jwks.json call"); + + verify(restOperations).exchange(any(), eq(springTypeReference)); + verify(restOperations).exchange(any(), eq(String.class)); + verifyNoMoreInteractions(restOperations); + } + @Test public void issuerWhenOidcFallbackContainsTrailingSlashThenSuccess() { this.issuer += "/"; diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index a6ff351287f..306fd7232b0 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -524,6 +524,23 @@ public void decodeWhenCacheIsConfiguredAndValueLoaderErrorsThenThrowsJwtExceptio .hasMessageContaining("An error occurred while attempting to decode the Jwt"); } + @Test + public void decoderBuildWithCustomRestOperations() throws Exception { + final RestOperations restOperations = mock(RestOperations.class); + when(restOperations.exchange(any(), eq(String.class))).thenThrow(new IllegalStateException("custom restOperations")); + + NimbusJwtDecoder jwtDecoder = withJwkSetUri("http://foo/.well-known/jwks.json") + .restOperations(restOperations) + .build(); + + assertThatThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)) + .hasRootCauseInstanceOf(IllegalStateException.class) + .hasRootCauseMessage("custom restOperations"); + + verify(restOperations).exchange(any(), eq(String.class)); + verifyNoMoreInteractions(restOperations); + } + private RSAPublicKey key() throws InvalidKeySpecException { byte[] decoded = Base64.getDecoder().decode(VERIFY_KEY.getBytes()); EncodedKeySpec spec = new X509EncodedKeySpec(decoded); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index 3109031fdb7..dc44bcdd3d6 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -429,6 +429,19 @@ public void jwsKeySelectorWhenMultipleAlgorithmThenReturnsCompositeSelector() { .isTrue(); } + @Test + public void jwtDecoderWithCustomWebClient() { + final WebClient webClient = WebClient.builder() + .exchangeFunction(request -> Mono.error(new IllegalStateException("custom webClient"))) + .build(); + + NimbusReactiveJwtDecoder decoder = withJwkSetUri(this.jwkSetUri).webClient(webClient).build(); + + assertThatCode(() -> decoder.decode(messageReadToken).block()) + .hasRootCauseInstanceOf(IllegalStateException.class) + .hasRootCauseMessage("custom webClient"); + } + private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) throws Exception { SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.parse(jwsAlgorithm.getName())), claimsSet); JWSSigner signer = new MACSigner(secretKey); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java index 3e6b411fc51..32d282608b8 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.jwt; import java.net.URI; +import java.security.interfaces.RSAPrivateKey; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -24,6 +25,14 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.oauth2.sdk.assertions.jwt.JWTAssertionDetails; +import com.nimbusds.oauth2.sdk.assertions.jwt.JWTAssertionFactory; +import com.nimbusds.oauth2.sdk.id.Audience; +import com.nimbusds.oauth2.sdk.id.Issuer; +import com.nimbusds.oauth2.sdk.id.Subject; import okhttp3.HttpUrl; import okhttp3.mockwebserver.Dispatcher; import okhttp3.mockwebserver.MockResponse; @@ -33,11 +42,21 @@ import org.junit.Before; import org.junit.Test; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.web.client.RestOperations; +import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.util.UriComponentsBuilder; +import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link ReactiveJwtDecoders} @@ -106,6 +125,36 @@ public void issuerWhenResponseIsTypicalThenReturnedDecoderValidatesIssuer() { .hasMessageContaining("The iss claim is not valid"); } + @Test + public void fromOidcIssuerLocationWithWebClient() throws JsonProcessingException, JOSEException { + final ParameterizedTypeReference> springTypeReference = + new ParameterizedTypeReference>() { + }; + final TypeReference> jacksonTypeReference = new TypeReference>() { + }; + + final String responseStr = String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer, this.issuer); + final ResponseEntity> response = ResponseEntity.ok(new ObjectMapper() + .readValue(responseStr, jacksonTypeReference)); + final RestOperations restOperations = mock(RestOperations.class); + when(restOperations.exchange(any(), eq(springTypeReference))).thenReturn(response); + + final WebClient webClient = WebClient.builder() + .exchangeFunction(request -> Mono.error(new IllegalStateException("custom webClient"))) + .build(); + + ReactiveJwtDecoder decoder = ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer, restOperations, webClient); + + final RSAPrivateKey privateKey = TestKeys.privateKey(); + final SignedJWT signedJWT = JWTAssertionFactory.create( + new JWTAssertionDetails(new Issuer("issuer"), new Subject("sub"), new Audience("aud")), + JWSAlgorithm.RS256, privateKey, "key-id", null); + + assertThatCode(() -> decoder.decode(signedJWT.serialize()).block()) + .hasRootCauseInstanceOf(IllegalStateException.class) + .hasRootCauseMessage("custom webClient"); + } + @Test public void issuerWhenOidcFallbackResponseIsTypicalThenReturnedDecoderValidatesIssuer() { prepareConfigurationResponseOidc(); diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolver.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolver.java index 97cc3d3fe9f..38d6d7e253e 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolver.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolver.java @@ -37,6 +37,7 @@ import org.springframework.security.oauth2.server.resource.web.BearerTokenResolver; import org.springframework.security.oauth2.server.resource.web.DefaultBearerTokenResolver; import org.springframework.util.Assert; +import org.springframework.web.client.RestOperations; /** * An implementation of {@link AuthenticationManagerResolver} that resolves a JWT-based {@link AuthenticationManager} @@ -63,7 +64,17 @@ public final class JwtIssuerAuthenticationManagerResolver implements Authenticat * @param trustedIssuers a list of trusted issuers */ public JwtIssuerAuthenticationManagerResolver(String... trustedIssuers) { - this(Arrays.asList(trustedIssuers)); + this(null, trustedIssuers); + } + + /** + * Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided parameters + * + * @param restOperations used as http-client (if not specified default http-client is used) + * @param trustedIssuers a whitelist of trusted issuers + */ + public JwtIssuerAuthenticationManagerResolver(RestOperations restOperations, String... trustedIssuers) { + this(Arrays.asList(trustedIssuers), restOperations); } /** @@ -72,10 +83,20 @@ public JwtIssuerAuthenticationManagerResolver(String... trustedIssuers) { * @param trustedIssuers a list of trusted issuers */ public JwtIssuerAuthenticationManagerResolver(Collection trustedIssuers) { + this(trustedIssuers, null); + } + + /** + * Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided parameters + * + * @param trustedIssuers a whitelist of trusted issuers + * @param restOperations used as http-client (if not specified default http-client is used) + */ + public JwtIssuerAuthenticationManagerResolver(Collection trustedIssuers, RestOperations restOperations) { Assert.notEmpty(trustedIssuers, "trustedIssuers cannot be empty"); this.issuerAuthenticationManagerResolver = new TrustedIssuerJwtAuthenticationManagerResolver - (Collections.unmodifiableCollection(trustedIssuers)::contains); + (Collections.unmodifiableCollection(trustedIssuers)::contains, restOperations); } /** @@ -143,16 +164,18 @@ private static class TrustedIssuerJwtAuthenticationManagerResolver private final Map authenticationManagers = new ConcurrentHashMap<>(); private final Predicate trustedIssuer; + private final RestOperations restOperations; - TrustedIssuerJwtAuthenticationManagerResolver(Predicate trustedIssuer) { + TrustedIssuerJwtAuthenticationManagerResolver(Predicate trustedIssuer, RestOperations restOperations) { this.trustedIssuer = trustedIssuer; + this.restOperations = restOperations; } @Override public AuthenticationManager resolve(String issuer) { if (this.trustedIssuer.test(issuer)) { return this.authenticationManagers.computeIfAbsent(issuer, k -> { - JwtDecoder jwtDecoder = JwtDecoders.fromIssuerLocation(issuer); + JwtDecoder jwtDecoder = JwtDecoders.fromIssuerLocation(issuer, restOperations); return new JwtAuthenticationProvider(jwtDecoder)::authenticate; }); } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolver.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolver.java index 0328d5ae3b8..45ee8ce04ac 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolver.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolver.java @@ -38,6 +38,8 @@ import org.springframework.security.oauth2.server.resource.InvalidBearerTokenException; import org.springframework.security.oauth2.server.resource.web.server.ServerBearerTokenAuthenticationConverter; import org.springframework.util.Assert; +import org.springframework.web.client.RestOperations; +import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.server.ServerWebExchange; /** @@ -69,7 +71,20 @@ public final class JwtIssuerReactiveAuthenticationManagerResolver * @param trustedIssuers a list of trusted issuers */ public JwtIssuerReactiveAuthenticationManagerResolver(String... trustedIssuers) { - this(Arrays.asList(trustedIssuers)); + this(null, null, trustedIssuers); + } + + /** + * Construct a {@link JwtIssuerReactiveAuthenticationManagerResolver} using the provided parameters + * + * @param restOperations used as http-client (if not specified default http-client is used) + * @param webClient used as http-client (if not specified default http-client is used) for reactive operations + * @param trustedIssuers a whitelist of trusted issuers + */ + public JwtIssuerReactiveAuthenticationManagerResolver(RestOperations restOperations, + WebClient webClient, + String... trustedIssuers) { + this(Arrays.asList(trustedIssuers), restOperations, webClient); } /** @@ -78,9 +93,25 @@ public JwtIssuerReactiveAuthenticationManagerResolver(String... trustedIssuers) * @param trustedIssuers a collection of trusted issuers */ public JwtIssuerReactiveAuthenticationManagerResolver(Collection trustedIssuers) { + this(trustedIssuers, null, null); + } + + /** + * Construct a {@link JwtIssuerReactiveAuthenticationManagerResolver} using the provided parameters + * + * @param trustedIssuers a whitelist of trusted issuers + * @param restOperations used as http-client (if not specified default http-client is used) + * @param webClient used as http-client (if not specified default http-client is used) for reactive operations + */ + public JwtIssuerReactiveAuthenticationManagerResolver(Collection trustedIssuers, + RestOperations restOperations, + WebClient webClient) { Assert.notEmpty(trustedIssuers, "trustedIssuers cannot be empty"); this.issuerAuthenticationManagerResolver = - new TrustedIssuerJwtAuthenticationManagerResolver(new ArrayList<>(trustedIssuers)::contains); + new TrustedIssuerJwtAuthenticationManagerResolver( + new ArrayList<>(trustedIssuers)::contains, + restOperations, + webClient); } /** @@ -155,9 +186,15 @@ private static class TrustedIssuerJwtAuthenticationManagerResolver private final Map> authenticationManagers = new ConcurrentHashMap<>(); private final Predicate trustedIssuer; + private final RestOperations restOperations; + private final WebClient webClient; - TrustedIssuerJwtAuthenticationManagerResolver(Predicate trustedIssuer) { + TrustedIssuerJwtAuthenticationManagerResolver(Predicate trustedIssuer, + RestOperations restOperations, + WebClient webClient) { this.trustedIssuer = trustedIssuer; + this.restOperations = restOperations; + this.webClient = webClient; } @Override @@ -167,7 +204,7 @@ public Mono resolve(String issuer) { } return this.authenticationManagers.computeIfAbsent(issuer, k -> Mono.fromCallable(() -> - new JwtReactiveAuthenticationManager(ReactiveJwtDecoders.fromIssuerLocation(k)) + new JwtReactiveAuthenticationManager(ReactiveJwtDecoders.fromIssuerLocation(k, restOperations, webClient)) ) .subscribeOn(Schedulers.boundedElastic()) .cache()); diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolverTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolverTests.java index 21df5189fe9..b9c11715203 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolverTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolverTests.java @@ -32,16 +32,20 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.Test; - +import org.springframework.core.ParameterizedTypeReference; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManagerResolver; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; /** @@ -179,6 +183,26 @@ public void constructorWhenNullAuthenticationManagerResolverThenException() { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void resolverWithCustomRestOperations() throws Exception { + JWSObject jws = new JWSObject(new JWSHeader(JWSAlgorithm.RS256), + new Payload(new JSONObject(Collections.singletonMap(ISS, "issuer")))); + jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY)); + + final RestOperations restOperations = mock(RestOperations.class); + when(restOperations.exchange(any(), any(ParameterizedTypeReference.class))).thenThrow(new IllegalStateException("custom restOperations")); + + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = + new JwtIssuerAuthenticationManagerResolver(restOperations, "issuer"); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Bearer " + jws.serialize()); + + assertThatThrownBy(() -> authenticationManagerResolver.resolve(request)) + .hasRootCauseInstanceOf(IllegalStateException.class) + .hasRootCauseMessage("custom restOperations"); + } + private String jwt(String claim, String value) { PlainJWT jwt = new PlainJWT(new JWTClaimsSet.Builder().claim(claim, value).build()); return jwt.serialize(); diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolverTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolverTests.java index d694707096f..470b4450a4b 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolverTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolverTests.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; +import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSObject; @@ -32,18 +33,26 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.Test; -import reactor.core.publisher.Mono; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; /** @@ -67,13 +76,11 @@ public void resolveWhenUsingTrustedIssuerThenReturnsAuthenticationManager() thro .setResponseCode(200) .setHeader("Content-Type", "application/json") .setBody(String.format(DEFAULT_RESPONSE_TEMPLATE, issuer, issuer))); - JWSObject jws = new JWSObject(new JWSHeader(JWSAlgorithm.RS256), - new Payload(new JSONObject(Collections.singletonMap(ISS, issuer)))); - jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY)); + String jwt = jwt(issuer); JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerReactiveAuthenticationManagerResolver(issuer); - MockServerWebExchange exchange = withBearerToken(jws.serialize()); + MockServerWebExchange exchange = withBearerToken(jwt); ReactiveAuthenticationManager authenticationManager = authenticationManagerResolver.resolve(exchange).block(); @@ -173,6 +180,52 @@ public void constructorWhenNullAuthenticationManagerResolverThenException() { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void resolverWithCustomRestOperations() throws Exception { + String jwt = jwt("issuer"); + + final WebClient webClient = WebClient.builder() + .exchangeFunction(request -> Mono.error(new IllegalStateException("custom webClient"))) + .build(); + final RestOperations restOperations = mock(RestOperations.class); + when(restOperations.exchange(any(), any(ParameterizedTypeReference.class))).thenThrow(new IllegalStateException("custom restOperations")); + + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = + new JwtIssuerReactiveAuthenticationManagerResolver(restOperations, webClient, "issuer"); + MockServerWebExchange exchange = withBearerToken(jwt); + + assertThatThrownBy(() -> authenticationManagerResolver.resolve(exchange).block()) + .hasRootCauseInstanceOf(IllegalStateException.class) + .hasRootCauseMessage("custom restOperations"); + } + + @Test + public void resolverWithCustomWebClient() throws Exception { + try (MockWebServer server = new MockWebServer()) { + String issuer = server.url("").toString(); + server.enqueue(new MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody(String.format(DEFAULT_RESPONSE_TEMPLATE, issuer, issuer))); + String jwt = jwt(issuer); + + final WebClient webClient = WebClient.builder() + .exchangeFunction(request -> Mono.error(new IllegalStateException("custom webClient"))) + .build(); + + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = + new JwtIssuerReactiveAuthenticationManagerResolver(new RestTemplate(), webClient, issuer); + MockServerWebExchange exchange = withBearerToken(jwt); + + assertThatThrownBy(() -> + authenticationManagerResolver.resolve(exchange) + .flatMap(authenticationManager -> authenticationManager.authenticate(new BearerTokenAuthenticationToken(jwt))) + .block() + ).hasRootCauseInstanceOf(IllegalStateException.class) + .hasRootCauseMessage("custom webClient"); + } + } + private String jwt(String claim, String value) { PlainJWT jwt = new PlainJWT(new JWTClaimsSet.Builder().claim(claim, value).build()); return jwt.serialize(); @@ -183,4 +236,15 @@ private MockServerWebExchange withBearerToken(String token) { .header("Authorization", "Bearer " + token).build(); return MockServerWebExchange.from(request); } + + private static String jwt(String issuer) { + JWSObject jws = new JWSObject(new JWSHeader(JWSAlgorithm.RS256), + new Payload(new JSONObject(Collections.singletonMap(ISS, issuer)))); + try { + jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY)); + return jws.serialize(); + } catch (JOSEException e) { + throw new IllegalStateException(e); + } + } }