diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index ad469d17dda..c358a9bbe15 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -16,6 +16,7 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.client; import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.core.ResolvableType; import org.springframework.security.authentication.AuthenticationProvider; @@ -55,6 +56,7 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; @@ -488,6 +490,10 @@ public void init(B http) throws Exception { OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider = new OidcAuthorizationCodeAuthenticationProvider(accessTokenResponseClient, oidcUserService); + JwtDecoderFactory jwtDecoderFactory = this.getJwtDecoderFactoryBean(); + if (jwtDecoderFactory != null) { + oidcAuthorizationCodeAuthenticationProvider.setJwtDecoderFactory(jwtDecoderFactory); + } if (userAuthoritiesMapper != null) { oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); } @@ -541,6 +547,19 @@ protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingU return new AntPathRequestMatcher(loginProcessingUrl); } + @SuppressWarnings("unchecked") + private JwtDecoderFactory getJwtDecoderFactoryBean() { + ResolvableType type = ResolvableType.forClassWithGenerics(JwtDecoderFactory.class, ClientRegistration.class); + String[] names = this.getBuilder().getSharedObject(ApplicationContext.class).getBeanNamesForType(type); + if (names.length > 1) { + throw new NoUniqueBeanDefinitionException(type, names); + } + if (names.length == 1) { + return (JwtDecoderFactory) this.getBuilder().getSharedObject(ApplicationContext.class).getBean(names[0]); + } + return null; + } + private GrantedAuthoritiesMapper getGrantedAuthoritiesMapper() { GrantedAuthoritiesMapper grantedAuthoritiesMapper = this.getBuilder().getSharedObject(GrantedAuthoritiesMapper.class); diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 39acc42793b..ec123abe664 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -16,28 +16,6 @@ package org.springframework.security.config.web.server; -import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; -import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.match; -import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.notMatch; - -import java.io.IOException; -import java.io.PrintWriter; -import java.io.StringWriter; -import java.security.interfaces.RSAPublicKey; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.UUID; - -import reactor.core.publisher.Mono; -import reactor.util.context.Context; - import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.core.Ordered; @@ -55,6 +33,8 @@ import org.springframework.security.authorization.ReactiveAuthorizationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; @@ -80,6 +60,7 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter; import org.springframework.security.oauth2.server.resource.authentication.JwtReactiveAuthenticationManager; import org.springframework.security.oauth2.server.resource.authentication.ReactiveJwtAuthenticationConverterAdapter; @@ -92,6 +73,7 @@ import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.ServerAuthenticationEntryPoint; import org.springframework.security.web.server.WebFilterExchange; +import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilter; import org.springframework.security.web.server.authentication.AuthenticationWebFilter; import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint; @@ -159,9 +141,27 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; -import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilter; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.AuthorityUtils; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.security.interfaces.RSAPublicKey; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.function.Function; + +import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; +import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.match; +import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.notMatch; /** * A {@link ServerHttpSecurity} is similar to Spring Security's {@code HttpSecurity} but for WebFlux. @@ -618,7 +618,14 @@ private ReactiveAuthenticationManager createDefault() { boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent( "org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); if (oidcAuthenticationProviderEnabled) { - OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService()); + OidcAuthorizationCodeReactiveAuthenticationManager oidc = + new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService()); + ResolvableType type = ResolvableType.forClassWithGenerics( + ReactiveJwtDecoderFactory.class, ClientRegistration.class); + ReactiveJwtDecoderFactory jwtDecoderFactory = getBeanOrNull(type); + if (jwtDecoderFactory != null) { + oidc.setJwtDecoderFactory(jwtDecoderFactory); + } result = new DelegatingReactiveAuthenticationManager(oidc, result); } return result; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index 017d92a4eae..555f504dd5c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -19,11 +19,12 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.springframework.beans.PropertyAccessorFactory; +import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; import org.springframework.http.MediaType; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; @@ -50,7 +51,6 @@ import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; -import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -66,6 +66,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; @@ -81,6 +82,7 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -369,8 +371,7 @@ public void oauth2LoginWithCustomLoginPageThenRedirectCustomLoginPage() throws E @Test public void oidcLogin() throws Exception { // setup application context - loadConfig(OAuth2LoginConfig.class); - registerJwtDecoder(); + loadConfig(OAuth2LoginConfig.class, JwtDecoderFactoryConfig.class); // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); @@ -396,8 +397,7 @@ public void oidcLogin() throws Exception { @Test public void oidcLoginCustomWithConfigurer() throws Exception { // setup application context - loadConfig(OAuth2LoginConfigCustomWithConfigurer.class); - registerJwtDecoder(); + loadConfig(OAuth2LoginConfigCustomWithConfigurer.class, JwtDecoderFactoryConfig.class); // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); @@ -423,8 +423,7 @@ public void oidcLoginCustomWithConfigurer() throws Exception { @Test public void oidcLoginCustomWithBeanRegistration() throws Exception { // setup application context - loadConfig(OAuth2LoginConfigCustomWithBeanRegistration.class); - registerJwtDecoder(); + loadConfig(OAuth2LoginConfigCustomWithBeanRegistration.class, JwtDecoderFactoryConfig.class); // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); @@ -447,6 +446,15 @@ public void oidcLoginCustomWithBeanRegistration() throws Exception { assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OIDC_USER"); } + @Test + public void oidcLoginCustomWithNoUniqueJwtDecoderFactory() { + assertThatThrownBy(() -> loadConfig(OAuth2LoginConfig.class, NoUniqueJwtDecoderFactoryConfig.class)) + .hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) + .hasMessageContaining("No qualifying bean of type " + + "'org.springframework.security.oauth2.jwt.JwtDecoderFactory' " + + "available: expected single matching bean but found 2: jwtDecoderFactory1,jwtDecoderFactory2"); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -455,25 +463,6 @@ private void loadConfig(Class... configs) { this.context = applicationContext; } - private void registerJwtDecoder() { - JwtDecoder decoder = token -> { - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.SUB, "sub123"); - claims.put(IdTokenClaimNames.ISS, "http://localhost/iss"); - claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d")); - claims.put(IdTokenClaimNames.AZP, "clientId"); - return new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600), - Collections.singletonMap("header1", "value1"), claims); - }; - this.springSecurityFilterChain.getFilters("/login/oauth2/code/google").stream() - .filter(OAuth2LoginAuthenticationFilter.class::isInstance) - .findFirst() - .ifPresent(filter -> PropertyAccessorFactory.forDirectFieldAccess(filter) - .setPropertyValue( - "authenticationManager.providers[2].jwtDecoders['google']", - decoder)); - } - private OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(String... scopes) { return this.createOAuth2AuthorizationRequest(GOOGLE_CLIENT_REGISTRATION, scopes); } @@ -632,6 +621,43 @@ HttpSessionOAuth2AuthorizationRequestRepository oauth2AuthorizationRequestReposi } } + @Configuration + static class JwtDecoderFactoryConfig { + + @Bean + JwtDecoderFactory jwtDecoderFactory() { + return clientRegistration -> getJwtDecoder(); + } + + private static JwtDecoder getJwtDecoder() { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, "sub123"); + claims.put(IdTokenClaimNames.ISS, "http://localhost/iss"); + claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d")); + claims.put(IdTokenClaimNames.AZP, "clientId"); + Jwt jwt = new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600), + Collections.singletonMap("header1", "value1"), claims); + JwtDecoder jwtDecoder = mock(JwtDecoder.class); + when(jwtDecoder.decode(any())).thenReturn(jwt); + return jwtDecoder; + } + } + + @Configuration + static class NoUniqueJwtDecoderFactoryConfig { + + @Bean + JwtDecoderFactory jwtDecoderFactory1() { + return clientRegistration -> JwtDecoderFactoryConfig.getJwtDecoder(); + } + + @Bean + JwtDecoderFactory jwtDecoderFactory2() { + return clientRegistration -> JwtDecoderFactoryConfig.getJwtDecoder(); + } + + } + private static OAuth2AccessTokenResponseClient createOauth2AccessTokenResponseClient() { return request -> { Map additionalParameters = new HashMap<>(); diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index 6027f27b148..1b33b521b21 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -16,12 +16,6 @@ package org.springframework.security.config.web.server; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import org.junit.Rule; import org.junit.Test; import org.openqa.selenium.WebDriver; @@ -34,15 +28,29 @@ import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeReactiveAuthenticationManager; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.TestOAuth2Users; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; @@ -51,9 +59,17 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; - import reactor.core.publisher.Mono; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + /** * @author Rob Winch * @since 5.1 @@ -72,6 +88,12 @@ public class OAuth2LoginTests { .clientSecret("secret") .build(); + private static ClientRegistration google = CommonOAuth2Provider.GOOGLE + .getBuilder("google") + .clientId("client") + .clientSecret("secret") + .build(); + @Test public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() { this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire(); @@ -97,11 +119,6 @@ public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() { static class OAuth2LoginWithMulitpleClientRegistrations { @Bean InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() { - ClientRegistration google = CommonOAuth2Provider.GOOGLE - .getBuilder("google") - .clientId("client") - .clientSecret("secret") - .build(); return new InMemoryReactiveClientRegistrationRepository(github, google); } } @@ -182,6 +199,107 @@ public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { } } + @Test + public void oauth2LoginWhenCustomJwtDecoderFactoryThenUsed() { + this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class, + OAuth2LoginWithJwtDecoderFactoryBeanConfig.class).autowire(); + + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(this.springSecurity) + .build(); + + OAuth2LoginWithJwtDecoderFactoryBeanConfig config = this.spring.getContext() + .getBean(OAuth2LoginWithJwtDecoderFactoryBeanConfig.class); + + OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); + OAuth2AuthorizationCodeAuthenticationToken token = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken); + + ServerAuthenticationConverter converter = config.authenticationConverter; + when(converter.convert(any())).thenReturn(Mono.just(token)); + + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) + .tokenType(accessToken.getTokenType()) + .scopes(accessToken.getScopes()) + .additionalParameters(additionalParameters) + .build(); + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; + when(tokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + + OidcUser user = TestOidcUsers.create(); + ReactiveOAuth2UserService userService = config.userService; + when(userService.loadUser(any())).thenReturn(Mono.just(user)); + + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus().is3xxRedirection(); + + verify(config.jwtDecoderFactory).createDecoder(any()); + } + + @Configuration + static class OAuth2LoginWithJwtDecoderFactoryBeanConfig { + + ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); + + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = + mock(ReactiveOAuth2AccessTokenResponseClient.class); + + ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); + + ReactiveJwtDecoderFactory jwtDecoderFactory = spy(new JwtDecoderFactory()); + + @Bean + public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + // @formatter:off + http + .authorizeExchange() + .anyExchange().authenticated() + .and() + .oauth2Login() + .authenticationConverter(authenticationConverter) + .authenticationManager(authenticationManager()); + return http.build(); + // @formatter:on + } + + private ReactiveAuthenticationManager authenticationManager() { + OidcAuthorizationCodeReactiveAuthenticationManager oidc = + new OidcAuthorizationCodeReactiveAuthenticationManager(tokenResponseClient, userService); + oidc.setJwtDecoderFactory(jwtDecoderFactory()); + return oidc; + } + + @Bean + public ReactiveJwtDecoderFactory jwtDecoderFactory() { + return jwtDecoderFactory; + } + + private static class JwtDecoderFactory implements ReactiveJwtDecoderFactory { + + @Override + public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) { + return getJwtDecoder(); + } + + private ReactiveJwtDecoder getJwtDecoder() { + return token -> { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, "subject"); + claims.put(IdTokenClaimNames.ISS, "http://localhost/issuer"); + claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client")); + claims.put(IdTokenClaimNames.AZP, "client"); + Jwt jwt = new Jwt("id-token", Instant.now(), Instant.now().plusSeconds(3600), + Collections.singletonMap("header1", "value1"), claims); + return Mono.just(jwt); + }; + } + } + } + static class GitHubWebFilter implements WebFilter { @Override diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java index 87c1560f23c..30f5f8a6528 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java @@ -15,10 +15,6 @@ */ package org.springframework.security.oauth2.client.oidc.authentication; -import java.util.Collection; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -31,9 +27,11 @@ import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; +import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; @@ -43,10 +41,17 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.JwtTimestampValidator; import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + import static org.springframework.security.oauth2.jwt.JwtProcessors.withJwkSetUri; /** @@ -80,7 +85,7 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier"; private final OAuth2AccessTokenResponseClient accessTokenResponseClient; private final OAuth2UserService userService; - private final Map jwtDecoders = new ConcurrentHashMap<>(); + private JwtDecoderFactory jwtDecoderFactory = new DefaultJwtDecoderFactory(); private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); /** @@ -174,6 +179,18 @@ public Authentication authenticate(Authentication authentication) throws Authent return authenticationResult; } + /** + * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature verification. + * The factory returns a {@link JwtDecoder} associated to the provided {@link ClientRegistration}. + * + * @since 5.2 + * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature verification + */ + public final void setJwtDecoderFactory(JwtDecoderFactory jwtDecoderFactory) { + Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); + this.jwtDecoderFactory = jwtDecoderFactory; + } + /** * Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OidcUser#getAuthorities()}} * to a new set of authorities which will be associated to the {@link OAuth2LoginAuthenticationToken}. @@ -191,30 +208,41 @@ public boolean supports(Class authentication) { } private OidcIdToken createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) { - JwtDecoder jwtDecoder = getJwtDecoder(clientRegistration); - Jwt jwt = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get( - OidcParameterNames.ID_TOKEN)); + JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); + Jwt jwt; + try { + jwt = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN)); + } catch (JwtException ex) { + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null); + throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); + } OidcIdToken idToken = new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()); - OidcTokenValidator.validateIdToken(idToken, clientRegistration); return idToken; } - private JwtDecoder getJwtDecoder(ClientRegistration clientRegistration) { - JwtDecoder jwtDecoder = this.jwtDecoders.get(clientRegistration.getRegistrationId()); - if (jwtDecoder == null) { - if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) { - OAuth2Error oauth2Error = new OAuth2Error( - MISSING_SIGNATURE_VERIFIER_ERROR_CODE, - "Failed to find a Signature Verifier for Client Registration: '" + - clientRegistration.getRegistrationId() + "'. Check to ensure you have configured the JwkSet URI.", - null - ); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri(); - jwtDecoder = new NimbusJwtDecoder(withJwkSetUri(jwkSetUri).build()); - this.jwtDecoders.put(clientRegistration.getRegistrationId(), jwtDecoder); + private static class DefaultJwtDecoderFactory implements JwtDecoderFactory { + private final Map jwtDecoders = new ConcurrentHashMap<>(); + + @Override + public JwtDecoder createDecoder(ClientRegistration clientRegistration) { + return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> { + if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) { + OAuth2Error oauth2Error = new OAuth2Error( + MISSING_SIGNATURE_VERIFIER_ERROR_CODE, + "Failed to find a Signature Verifier for Client Registration: '" + + clientRegistration.getRegistrationId() + + "'. Check to ensure you have configured the JwkSet URI.", + null + ); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri(); + NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withJwkSetUri(jwkSetUri).build()); + OAuth2TokenValidator jwtValidator = new DelegatingOAuth2TokenValidator<>( + new JwtTimestampValidator(), new OidcIdTokenValidator(clientRegistration)); + jwtDecoder.setJwtValidator(jwtValidator); + return jwtDecoder; + }); } - return jwtDecoder; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java index d53c96f273d..dd811f7e474 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java @@ -26,10 +26,12 @@ import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; @@ -37,8 +39,12 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.JwtTimestampValidator; import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import reactor.core.publisher.Mono; @@ -46,7 +52,6 @@ import java.util.Collection; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; /** * An implementation of an {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth 2.0 Login, @@ -86,7 +91,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); - private Function decoderFactory = new DefaultDecoderFactory(); + private ReactiveJwtDecoderFactory jwtDecoderFactory = new DefaultJwtDecoderFactory(); public OidcAuthorizationCodeReactiveAuthenticationManager( ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient, @@ -138,18 +143,24 @@ public Mono authenticate(Authentication authentication) { return this.accessTokenResponseClient.getTokenResponse(authzRequest) .flatMap(accessTokenResponse -> authenticationResult(authorizationCodeAuthentication, accessTokenResponse)) - .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())); + .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) + .onErrorMap(JwtException.class, e -> { + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null); + throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e); + }); }); } /** - * Provides a way to customize the {@link ReactiveJwtDecoder} given a {@link ClientRegistration} - * @param decoderFactory the {@link Function} used to create {@link ReactiveJwtDecoder} instance. Cannot be null. + * Sets the {@link ReactiveJwtDecoderFactory} used for {@link OidcIdToken} signature verification. + * The factory returns a {@link ReactiveJwtDecoder} associated to the provided {@link ClientRegistration}. + * + * @since 5.2 + * @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} used for {@link OidcIdToken} signature verification */ - void setDecoderFactory( - Function decoderFactory) { - Assert.notNull(decoderFactory, "decoderFactory cannot be null"); - this.decoderFactory = decoderFactory; + public final void setJwtDecoderFactory(ReactiveJwtDecoderFactory jwtDecoderFactory) { + Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); + this.jwtDecoderFactory = jwtDecoderFactory; } private Mono authenticationResult(OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { @@ -183,33 +194,35 @@ private Mono authenticationResult(OAuth2Authoriz } private Mono createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) { - ReactiveJwtDecoder jwtDecoder = this.decoderFactory.apply(clientRegistration); + ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); String rawIdToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN); return jwtDecoder.decode(rawIdToken) - .map(jwt -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims())) - .doOnNext(idToken -> OidcTokenValidator.validateIdToken(idToken, clientRegistration)); + .map(jwt -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims())); } - private static class DefaultDecoderFactory implements Function { + private static class DefaultJwtDecoderFactory implements ReactiveJwtDecoderFactory { private final Map jwtDecoders = new ConcurrentHashMap<>(); @Override - public ReactiveJwtDecoder apply(ClientRegistration clientRegistration) { - ReactiveJwtDecoder jwtDecoder = this.jwtDecoders.get(clientRegistration.getRegistrationId()); - if (jwtDecoder == null) { + public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) { + return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> { if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) { OAuth2Error oauth2Error = new OAuth2Error( MISSING_SIGNATURE_VERIFIER_ERROR_CODE, "Failed to find a Signature Verifier for Client Registration: '" + - clientRegistration.getRegistrationId() + "'. Check to ensure you have configured the JwkSet URI.", + clientRegistration.getRegistrationId() + + "'. Check to ensure you have configured the JwkSet URI.", null ); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - jwtDecoder = new NimbusReactiveJwtDecoder(clientRegistration.getProviderDetails().getJwkSetUri()); - this.jwtDecoders.put(clientRegistration.getRegistrationId(), jwtDecoder); - } - return jwtDecoder; + NimbusReactiveJwtDecoder jwtDecoder = new NimbusReactiveJwtDecoder( + clientRegistration.getProviderDetails().getJwkSetUri()); + OAuth2TokenValidator jwtValidator = new DelegatingOAuth2TokenValidator<>( + new JwtTimestampValidator(), new OidcIdTokenValidator(clientRegistration)); + jwtDecoder.setJwtValidator(jwtValidator); + return jwtDecoder; + }); } } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcTokenValidator.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java similarity index 66% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcTokenValidator.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java index b646015a667..d8852992f4d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcTokenValidator.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.client.oidc.authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import java.net.URL; @@ -27,36 +30,50 @@ import java.util.List; /** + * An {@link OAuth2TokenValidator} responsible for + * validating the claims in an {@link OidcIdToken ID Token}. + * * @author Rob Winch + * @author Joe Grandja * @since 5.1 + * @see OAuth2TokenValidator + * @see Jwt + * @see ID Token Validation */ -final class OidcTokenValidator { - private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; +public final class OidcIdTokenValidator implements OAuth2TokenValidator { + private static final OAuth2Error INVALID_ID_TOKEN_ERROR = new OAuth2Error("invalid_id_token"); + private final ClientRegistration clientRegistration; - static void validateIdToken(OidcIdToken idToken, ClientRegistration clientRegistration) { + public OidcIdTokenValidator(ClientRegistration clientRegistration) { + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + this.clientRegistration = clientRegistration; + } + + @Override + public OAuth2TokenValidatorResult validate(Jwt idToken) { // 3.1.3.7 ID Token Validation // http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // Validate REQUIRED Claims URL issuer = idToken.getIssuer(); if (issuer == null) { - throwInvalidIdTokenException(); + return invalidIdToken(); } String subject = idToken.getSubject(); if (subject == null) { - throwInvalidIdTokenException(); + return invalidIdToken(); } List audience = idToken.getAudience(); if (CollectionUtils.isEmpty(audience)) { - throwInvalidIdTokenException(); + return invalidIdToken(); } Instant expiresAt = idToken.getExpiresAt(); if (expiresAt == null) { - throwInvalidIdTokenException(); + return invalidIdToken(); } Instant issuedAt = idToken.getIssuedAt(); if (issuedAt == null) { - throwInvalidIdTokenException(); + return invalidIdToken(); } // 2. The Issuer Identifier for the OpenID Provider (which is typically obtained during Discovery) @@ -68,21 +85,21 @@ static void validateIdToken(OidcIdToken idToken, ClientRegistration clientRegist // The aud (audience) Claim MAY contain an array with more than one element. // The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience, // or if it contains additional audiences not trusted by the Client. - if (!audience.contains(clientRegistration.getClientId())) { - throwInvalidIdTokenException(); + if (!audience.contains(this.clientRegistration.getClientId())) { + return invalidIdToken(); } // 4. If the ID Token contains multiple audiences, // the Client SHOULD verify that an azp Claim is present. - String authorizedParty = idToken.getAuthorizedParty(); + String authorizedParty = idToken.getClaimAsString(IdTokenClaimNames.AZP); if (audience.size() > 1 && authorizedParty == null) { - throwInvalidIdTokenException(); + return invalidIdToken(); } // 5. If an azp (authorized party) Claim is present, // the Client SHOULD verify that its client_id is the Claim Value. - if (authorizedParty != null && !authorizedParty.equals(clientRegistration.getClientId())) { - throwInvalidIdTokenException(); + if (authorizedParty != null && !authorizedParty.equals(this.clientRegistration.getClientId())) { + return invalidIdToken(); } // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client @@ -92,7 +109,7 @@ static void validateIdToken(OidcIdToken idToken, ClientRegistration clientRegist // 9. The current time MUST be before the time represented by the exp Claim. Instant now = Instant.now(); if (!now.isBefore(expiresAt)) { - throwInvalidIdTokenException(); + return invalidIdToken(); } // 10. The iat Claim can be used to reject tokens that were issued too far away from the current time, @@ -100,7 +117,7 @@ static void validateIdToken(OidcIdToken idToken, ClientRegistration clientRegist // The acceptable range is Client specific. Instant maxIssuedAt = now.plusSeconds(30); if (issuedAt.isAfter(maxIssuedAt)) { - throwInvalidIdTokenException(); + return invalidIdToken(); } // 11. If a nonce value was sent in the Authentication Request, @@ -110,12 +127,10 @@ static void validateIdToken(OidcIdToken idToken, ClientRegistration clientRegist // The precise method for detecting replay attacks is Client specific. // TODO Depends on gh-4442 + return OAuth2TokenValidatorResult.success(); } - private static void throwInvalidIdTokenException() { - OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE); - throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); + private static OAuth2TokenValidatorResult invalidIdToken() { + return OAuth2TokenValidatorResult.failure(INVALID_ID_TOKEN_ERROR); } - - private OidcTokenValidator() {} } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java index 2cc872351aa..31f162f2d59 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java @@ -15,22 +15,12 @@ */ package org.springframework.security.oauth2.client.oidc.authentication; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; - import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; @@ -52,13 +42,20 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; -import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.security.oauth2.jwt.JwtException; + +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyCollection; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; @@ -86,7 +83,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { @Before @SuppressWarnings("unchecked") - public void setUp() throws Exception { + public void setUp() { this.clientRegistration = clientRegistration().clientId("client1").build(); this.authorizationRequest = request().scope("openid", "profile", "email").build(); this.authorizationResponse = success().build(); @@ -112,6 +109,12 @@ public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null); } + @Test + public void setJwtDecoderFactoryWhenNullThenThrowIllegalArgumentException() { + this.exception.expect(IllegalArgumentException.class); + this.authenticationProvider.setJwtDecoderFactory(null); + } + @Test public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() { this.exception.expect(IllegalArgumentException.class); @@ -202,139 +205,20 @@ public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationExceptio } @Test - public void authenticateWhenIdTokenIssuerClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); - - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.SUB, "subject1"); - - this.setUpIdToken(claims); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - } - - @Test - public void authenticateWhenIdTokenSubjectClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); - - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.ISS, "https://provider.com"); - - this.setUpIdToken(claims); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - } - - @Test - public void authenticateWhenIdTokenAudienceClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); - - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.ISS, "https://provider.com"); - claims.put(IdTokenClaimNames.SUB, "subject1"); - - this.setUpIdToken(claims); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - } - - @Test - public void authenticateWhenIdTokenAudienceClaimDoesNotContainClientIdThenThrowOAuth2AuthenticationException() throws Exception { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); - - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.ISS, "https://provider.com"); - claims.put(IdTokenClaimNames.SUB, "subject1"); - claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client")); - - this.setUpIdToken(claims); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - } - - @Test - public void authenticateWhenIdTokenMultipleAudienceClaimAndAuthorizedPartyClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); - - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.ISS, "https://provider.com"); - claims.put(IdTokenClaimNames.SUB, "subject1"); - claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2")); - - this.setUpIdToken(claims); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - } - - @Test - public void authenticateWhenIdTokenAuthorizedPartyClaimNotEqualToClientIdThenThrowOAuth2AuthenticationException() throws Exception { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); - - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.ISS, "https://provider.com"); - claims.put(IdTokenClaimNames.SUB, "subject1"); - claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2")); - claims.put(IdTokenClaimNames.AZP, "other-client"); - - this.setUpIdToken(claims); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - } - - @Test - public void authenticateWhenIdTokenExpiresAtIsBeforeNowThenThrowOAuth2AuthenticationException() throws Exception { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); - - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.ISS, "https://provider.com"); - claims.put(IdTokenClaimNames.SUB, "subject1"); - claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2")); - claims.put(IdTokenClaimNames.AZP, "client1"); - - Instant issuedAt = Instant.now().minusSeconds(10); - Instant expiresAt = Instant.from(issuedAt).plusSeconds(5); - - this.setUpIdToken(claims, issuedAt, expiresAt); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - } - - @Test - public void authenticateWhenIdTokenIssuedAtIsAfterMaxIssuedAtThenThrowOAuth2AuthenticationException() throws Exception { + public void authenticateWhenIdTokenValidationErrorThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); + this.exception.expectMessage(containsString("[invalid_id_token] ID Token Validation Error")); - Map claims = new HashMap<>(); - claims.put(IdTokenClaimNames.ISS, "https://provider.com"); - claims.put(IdTokenClaimNames.SUB, "subject1"); - claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2")); - claims.put(IdTokenClaimNames.AZP, "client1"); - - Instant issuedAt = Instant.now().plusSeconds(35); - Instant expiresAt = Instant.from(issuedAt).plusSeconds(60); - - this.setUpIdToken(claims, issuedAt, expiresAt); + JwtDecoder jwtDecoder = mock(JwtDecoder.class); + when(jwtDecoder.decode(anyString())).thenThrow(new JwtException("ID Token Validation Error")); + this.authenticationProvider.setJwtDecoderFactory(registration -> jwtDecoder); this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); } @Test - public void authenticateWhenLoginSuccessThenReturnAuthentication() throws Exception { + public void authenticateWhenLoginSuccessThenReturnAuthentication() { Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://provider.com"); claims.put(IdTokenClaimNames.SUB, "subject1"); @@ -363,7 +247,7 @@ public void authenticateWhenLoginSuccessThenReturnAuthentication() throws Except } @Test - public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() throws Exception { + public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() { Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://provider.com"); claims.put(IdTokenClaimNames.SUB, "subject1"); @@ -392,7 +276,7 @@ public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() th // gh-5368 @Test - public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() throws Exception { + public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() { Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://provider.com"); claims.put(IdTokenClaimNames.SUB, "subject1"); @@ -414,13 +298,13 @@ public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToU this.accessTokenResponse.getAdditionalParameters()); } - private void setUpIdToken(Map claims) throws Exception { + private void setUpIdToken(Map claims) { Instant issuedAt = Instant.now(); Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600); this.setUpIdToken(claims, issuedAt, expiresAt); } - private void setUpIdToken(Map claims, Instant issuedAt, Instant expiresAt) throws Exception { + private void setUpIdToken(Map claims, Instant issuedAt, Instant expiresAt) { Map headers = new HashMap<>(); headers.put("alg", "RS256"); @@ -428,8 +312,7 @@ private void setUpIdToken(Map claims, Instant issuedAt, Instant JwtDecoder jwtDecoder = mock(JwtDecoder.class); when(jwtDecoder.decode(anyString())).thenReturn(idToken); - ReflectionTestUtils.setField(this.authenticationProvider, - "jwtDecoders", Collections.singletonMap("registration-id", jwtDecoder)); + this.authenticationProvider.setJwtDecoderFactory(registration -> jwtDecoder); } private OAuth2AccessTokenResponse accessTokenSuccessResponse() { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java index 8cfcb628b31..7b6f534950c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java @@ -44,6 +44,7 @@ import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import reactor.core.publisher.Mono; @@ -53,9 +54,7 @@ import java.util.HashMap; import java.util.Map; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -105,6 +104,12 @@ public void constructorWhenNullUserServiceThenIllegalArgumentException() { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setJwtDecoderFactoryWhenNullThenIllegalArgumentException() { + assertThatThrownBy(() -> this.manager.setJwtDecoderFactory(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void authenticateWhenNoSubscriptionThenDoesNothing() { // we didn't do anything because it should cause a ClassCastException (as verified below) @@ -139,6 +144,22 @@ public void authenticationWhenStateDoesNotMatchThenOAuth2AuthenticationException .isInstanceOf(OAuth2AuthenticationException.class); } + @Test + public void authenticateWhenIdTokenValidationErrorThenOAuth2AuthenticationException() { + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + + when(this.jwtDecoder.decode(any())).thenThrow(new JwtException("ID Token Validation Error")); + this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); + + assertThatThrownBy(() -> this.manager.authenticate(loginToken()).block()) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("[invalid_id_token] ID Token Validation Error"); + } + @Test public void authenticationWhenOAuth2UserNotFoundThenEmpty() { OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") @@ -157,7 +178,7 @@ public void authenticationWhenOAuth2UserNotFoundThenEmpty() { when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); when(this.userService.loadUser(any())).thenReturn(Mono.empty()); when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setDecoderFactory(c -> this.jwtDecoder); + this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); assertThat(this.manager.authenticate(loginToken()).block()).isNull(); } @@ -180,7 +201,7 @@ public void authenticationWhenOAuth2UserFoundThenSuccess() { DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken); when(this.userService.loadUser(any())).thenReturn(Mono.just(user)); when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setDecoderFactory(c -> this.jwtDecoder); + this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block(); @@ -209,7 +230,7 @@ public void authenticationWhenRefreshTokenThenRefreshTokenInAuthorizedClient() { DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken); when(this.userService.loadUser(any())).thenReturn(Mono.just(user)); when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setDecoderFactory(c -> this.jwtDecoder); + this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block(); @@ -245,7 +266,7 @@ public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToU ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user)); when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setDecoderFactory(c -> this.jwtDecoder); + this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); this.manager.authenticate(loginToken()).block(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java new file mode 100644 index 00000000000..088bfbe900f --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java @@ -0,0 +1,186 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.oidc.authentication; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; +import org.springframework.security.oauth2.jwt.Jwt; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Rob Winch + * @author Joe Grandja + * @since 5.1 + */ +public class OidcIdTokenValidatorTests { + private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration(); + private Map headers = new HashMap<>(); + private Map claims = new HashMap<>(); + private Instant issuedAt = Instant.now(); + private Instant expiresAt = this.issuedAt.plusSeconds(3600); + + @Before + public void setup() { + this.headers.put("alg", JwsAlgorithms.RS256); + this.claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); + this.claims.put(IdTokenClaimNames.SUB, "rob"); + this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client-id")); + } + + @Test + public void validateIdTokenWhenValidThenNoErrors() { + assertThat(this.validateIdToken()).isEmpty(); + } + + @Test + public void validateIdTokenWhenIssuerNullThenHasErrors() { + this.claims.remove(IdTokenClaimNames.ISS); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenSubNullThenHasErrors() { + this.claims.remove(IdTokenClaimNames.SUB); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenAudNullThenHasErrors() { + this.claims.remove(IdTokenClaimNames.AUD); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenIssuedAtNullThenHasErrors() { + this.issuedAt = null; + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenExpiresAtNullThenHasErrors() { + this.expiresAt = null; + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenAudMultipleAndAzpNullThenHasErrors() { + this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other")); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenAzpNotClientIdThenHasErrors() { + this.claims.put(IdTokenClaimNames.AZP, "other"); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenMultipleAudAzpClientIdThenNoErrors() { + this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other")); + this.claims.put(IdTokenClaimNames.AZP, "client-id"); + assertThat(this.validateIdToken()).isEmpty(); + } + + @Test + public void validateIdTokenWhenMultipleAudAzpNotClientIdThenHasErrors() { + this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id-1", "client-id-2")); + this.claims.put(IdTokenClaimNames.AZP, "other-client"); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenAudNotClientIdThenHasErrors() { + this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client")); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenExpiredThenHasErrors() { + this.issuedAt = Instant.now().minus(Duration.ofMinutes(1)); + this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1)); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenIssuedAtWayInFutureThenHasErrors() { + this.issuedAt = Instant.now().plus(Duration.ofMinutes(5)); + this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1)); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + @Test + public void validateIdTokenWhenExpiresAtBeforeNowThenHasErrors() { + this.issuedAt = Instant.now().minusSeconds(10); + this.expiresAt = Instant.from(this.issuedAt).plusSeconds(5); + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getErrorCode) + .contains("invalid_id_token"); + } + + private Collection validateIdToken() { + Jwt idToken = new Jwt("token123", this.issuedAt, this.expiresAt, this.headers, this.claims); + OidcIdTokenValidator validator = new OidcIdTokenValidator(this.registration.build()); + return validator.validate(idToken).getErrors(); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcTokenValidatorTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcTokenValidatorTests.java deleted file mode 100644 index cb95ada199e..00000000000 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcTokenValidatorTests.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright 2002-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.oidc.authentication; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; - -import java.time.Duration; -import java.time.Instant; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - -import static org.assertj.core.api.Assertions.assertThatCode; - -/** - * @author Rob Winch - * @since 5.1 - */ -public class OidcTokenValidatorTests { - private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration(); - - private Map claims = new HashMap<>(); - private Instant issuedAt = Instant.now(); - private Instant expiresAt = Instant.now().plusSeconds(3600); - - @Before - public void setup() { - this.claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); - this.claims.put(IdTokenClaimNames.SUB, "rob"); - this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id")); - } - - @Test - public void validateIdTokenWhenValidThenNoException() { - assertThatCode(() -> validateIdToken()) - .doesNotThrowAnyException(); - } - - @Test - public void validateIdTokenWhenIssuerNullThenException() { - this.claims.remove(IdTokenClaimNames.ISS); - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void validateIdTokenWhenSubNullThenException() { - this.claims.remove(IdTokenClaimNames.SUB); - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void validateIdTokenWhenAudNullThenException() { - this.claims.remove(IdTokenClaimNames.AUD); - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void validateIdTokenWhenIssuedAtNullThenException() { - this.issuedAt = null; - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void validateIdTokenWhenExpiresAtNullThenException() { - this.expiresAt = null; - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void validateIdTokenWhenAudMultipleAndAzpNullThenException() { - this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other")); - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void validateIdTokenWhenAzpNotClientIdThenException() { - this.claims.put(IdTokenClaimNames.AZP, "other"); - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void validateIdTokenWhenMulitpleAudAzpClientIdThenNoException() { - this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other")); - this.claims.put(IdTokenClaimNames.AZP, "client-id"); - assertThatCode(() -> validateIdToken()) - .doesNotThrowAnyException(); - } - - @Test - public void validateIdTokenWhenExpiredThenException() { - this.issuedAt = Instant.now().minus(Duration.ofMinutes(1)); - this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1)); - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - @Test - public void validateIdTokenWhenIssuedAtWayInFutureThenException() { - this.issuedAt = Instant.now().plus(Duration.ofMinutes(5)); - this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1)); - assertThatCode(() -> validateIdToken()) - .isInstanceOf(OAuth2AuthenticationException.class); - } - - private void validateIdToken() { - OidcIdToken token = new OidcIdToken("token123", this.issuedAt, this.expiresAt, this.claims); - OidcTokenValidator.validateIdToken(token, this.registration.build()); - } - -} diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationRequests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationRequests.java index 6b10c24343a..9386e3eeb74 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationRequests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationRequests.java @@ -32,6 +32,7 @@ public static OAuth2AuthorizationRequest.Builder request() { return OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/login/oauth/authorize") .clientId(clientId) + .scope("openid") .redirectUri("https://example.com/authorize/oauth2/code/registration-id") .state("state") .additionalParameters(additionalParameters); diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java new file mode 100644 index 00000000000..0aad2cef5bd --- /dev/null +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.core.oidc.user; + +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; + +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * @author Joe Grandja + */ +public class TestOidcUsers { + + public static DefaultOidcUser create() { + List roles = AuthorityUtils.createAuthorityList("ROLE_USER"); + return new DefaultOidcUser(roles, idToken()); + } + + private static OidcIdToken idToken() { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, "subject"); + claims.put(IdTokenClaimNames.ISS, "http://localhost/issuer"); + claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client")); + claims.put(IdTokenClaimNames.AZP, "client"); + return new OidcIdToken("id-token", Instant.now(), Instant.now().plusSeconds(3600), claims); + } +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderFactory.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderFactory.java new file mode 100644 index 00000000000..b33459962d2 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderFactory.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +/** + * A factory for {@link JwtDecoder}(s). + * This factory should be supplied with a type that provides + * contextual information used to create a specific {@code JwtDecoder}. + * + * @author Joe Grandja + * @since 5.2 + * @see JwtDecoder + * + * @param The type that provides contextual information used to create a specific {@code JwtDecoder}. + */ +public interface JwtDecoderFactory { + + /** + * Creates a {@code JwtDecoder} using the supplied "contextual" type. + * + * @param context the type that provides contextual information + * @return a {@link JwtDecoder} + */ + JwtDecoder createDecoder(C context); + +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoderFactory.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoderFactory.java new file mode 100644 index 00000000000..3a72f246b46 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoderFactory.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.jwt; + +/** + * A factory for {@link ReactiveJwtDecoder}(s). + * This factory should be supplied with a type that provides + * contextual information used to create a specific {@code ReactiveJwtDecoder}. + * + * @author Joe Grandja + * @since 5.2 + * @see ReactiveJwtDecoder + * + * @param The type that provides contextual information used to create a specific {@code ReactiveJwtDecoder}. + */ +public interface ReactiveJwtDecoderFactory { + + /** + * Creates a {@code ReactiveJwtDecoder} using the supplied "contextual" type. + * + * @param context the type that provides contextual information + * @return a {@link ReactiveJwtDecoder} + */ + ReactiveJwtDecoder createDecoder(C context); + +}