Skip to content

Extract OidcTokenValidator to an OAuth2TokenValidator #6298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.config.annotation.web.configurers.oauth2.client;

import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.security.authentication.AuthenticationProvider;
Expand Down Expand Up @@ -55,6 +56,7 @@
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
Expand Down Expand Up @@ -488,6 +490,10 @@ public void init(B http) throws Exception {

OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider =
new OidcAuthorizationCodeAuthenticationProvider(accessTokenResponseClient, oidcUserService);
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = this.getJwtDecoderFactoryBean();
if (jwtDecoderFactory != null) {
oidcAuthorizationCodeAuthenticationProvider.setJwtDecoderFactory(jwtDecoderFactory);
}
if (userAuthoritiesMapper != null) {
oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
}
Expand Down Expand Up @@ -541,6 +547,19 @@ protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingU
return new AntPathRequestMatcher(loginProcessingUrl);
}

@SuppressWarnings("unchecked")
private JwtDecoderFactory<ClientRegistration> getJwtDecoderFactoryBean() {
ResolvableType type = ResolvableType.forClassWithGenerics(JwtDecoderFactory.class, ClientRegistration.class);
String[] names = this.getBuilder().getSharedObject(ApplicationContext.class).getBeanNamesForType(type);
if (names.length > 1) {
throw new NoUniqueBeanDefinitionException(type, names);
}
if (names.length == 1) {
return (JwtDecoderFactory<ClientRegistration>) this.getBuilder().getSharedObject(ApplicationContext.class).getBean(names[0]);
}
return null;
}

private GrantedAuthoritiesMapper getGrantedAuthoritiesMapper() {
GrantedAuthoritiesMapper grantedAuthoritiesMapper =
this.getBuilder().getSharedObject(GrantedAuthoritiesMapper.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,6 @@

package org.springframework.security.config.web.server;

import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.match;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.notMatch;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.UUID;

import reactor.core.publisher.Mono;
import reactor.util.context.Context;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.Ordered;
Expand All @@ -55,6 +33,8 @@
import org.springframework.security.authorization.ReactiveAuthorizationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
Expand All @@ -80,6 +60,7 @@
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter;
import org.springframework.security.oauth2.server.resource.authentication.JwtReactiveAuthenticationManager;
import org.springframework.security.oauth2.server.resource.authentication.ReactiveJwtAuthenticationConverterAdapter;
Expand All @@ -92,6 +73,7 @@
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.ServerAuthenticationEntryPoint;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilter;
import org.springframework.security.web.server.authentication.AuthenticationWebFilter;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint;
Expand Down Expand Up @@ -159,9 +141,27 @@
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilter;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;

import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.match;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.notMatch;

/**
* A {@link ServerHttpSecurity} is similar to Spring Security's {@code HttpSecurity} but for WebFlux.
Expand Down Expand Up @@ -618,7 +618,14 @@ private ReactiveAuthenticationManager createDefault() {
boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent(
"org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader());
if (oidcAuthenticationProviderEnabled) {
OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService());
OidcAuthorizationCodeReactiveAuthenticationManager oidc =
new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService());
ResolvableType type = ResolvableType.forClassWithGenerics(
ReactiveJwtDecoderFactory.class, ClientRegistration.class);
ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = getBeanOrNull(type);
if (jwtDecoderFactory != null) {
oidc.setJwtDecoderFactory(jwtDecoderFactory);
}
result = new DelegatingReactiveAuthenticationManager(oidc, result);
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.PropertyAccessorFactory;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
Expand All @@ -50,7 +51,6 @@
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
Expand All @@ -66,6 +66,7 @@
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
Expand All @@ -81,6 +82,7 @@
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -369,8 +371,7 @@ public void oauth2LoginWithCustomLoginPageThenRedirectCustomLoginPage() throws E
@Test
public void oidcLogin() throws Exception {
// setup application context
loadConfig(OAuth2LoginConfig.class);
registerJwtDecoder();
loadConfig(OAuth2LoginConfig.class, JwtDecoderFactoryConfig.class);

// setup authorization request
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid");
Expand All @@ -396,8 +397,7 @@ public void oidcLogin() throws Exception {
@Test
public void oidcLoginCustomWithConfigurer() throws Exception {
// setup application context
loadConfig(OAuth2LoginConfigCustomWithConfigurer.class);
registerJwtDecoder();
loadConfig(OAuth2LoginConfigCustomWithConfigurer.class, JwtDecoderFactoryConfig.class);

// setup authorization request
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid");
Expand All @@ -423,8 +423,7 @@ public void oidcLoginCustomWithConfigurer() throws Exception {
@Test
public void oidcLoginCustomWithBeanRegistration() throws Exception {
// setup application context
loadConfig(OAuth2LoginConfigCustomWithBeanRegistration.class);
registerJwtDecoder();
loadConfig(OAuth2LoginConfigCustomWithBeanRegistration.class, JwtDecoderFactoryConfig.class);

// setup authorization request
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid");
Expand All @@ -447,6 +446,15 @@ public void oidcLoginCustomWithBeanRegistration() throws Exception {
assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OIDC_USER");
}

@Test
public void oidcLoginCustomWithNoUniqueJwtDecoderFactory() {
assertThatThrownBy(() -> loadConfig(OAuth2LoginConfig.class, NoUniqueJwtDecoderFactoryConfig.class))
.hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class)
.hasMessageContaining("No qualifying bean of type " +
"'org.springframework.security.oauth2.jwt.JwtDecoderFactory<org.springframework.security.oauth2.client.registration.ClientRegistration>' " +
"available: expected single matching bean but found 2: jwtDecoderFactory1,jwtDecoderFactory2");
}

private void loadConfig(Class<?>... configs) {
AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext();
applicationContext.register(configs);
Expand All @@ -455,25 +463,6 @@ private void loadConfig(Class<?>... configs) {
this.context = applicationContext;
}

private void registerJwtDecoder() {
JwtDecoder decoder = token -> {
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.SUB, "sub123");
claims.put(IdTokenClaimNames.ISS, "http://localhost/iss");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d"));
claims.put(IdTokenClaimNames.AZP, "clientId");
return new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600),
Collections.singletonMap("header1", "value1"), claims);
};
this.springSecurityFilterChain.getFilters("/login/oauth2/code/google").stream()
.filter(OAuth2LoginAuthenticationFilter.class::isInstance)
.findFirst()
.ifPresent(filter -> PropertyAccessorFactory.forDirectFieldAccess(filter)
.setPropertyValue(
"authenticationManager.providers[2].jwtDecoders['google']",
decoder));
}

private OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(String... scopes) {
return this.createOAuth2AuthorizationRequest(GOOGLE_CLIENT_REGISTRATION, scopes);
}
Expand Down Expand Up @@ -632,6 +621,43 @@ HttpSessionOAuth2AuthorizationRequestRepository oauth2AuthorizationRequestReposi
}
}

@Configuration
static class JwtDecoderFactoryConfig {

@Bean
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory() {
return clientRegistration -> getJwtDecoder();
}

private static JwtDecoder getJwtDecoder() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to return a mock so that testing is easier?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could return a mock(JwtDecoder.class) but mocking the return value of decode() will produce the same amount of code so doesn't really make it easier. However, it might make more sense to do it this way so will apply the change.

Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.SUB, "sub123");
claims.put(IdTokenClaimNames.ISS, "http://localhost/iss");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d"));
claims.put(IdTokenClaimNames.AZP, "clientId");
Jwt jwt = new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600),
Collections.singletonMap("header1", "value1"), claims);
JwtDecoder jwtDecoder = mock(JwtDecoder.class);
when(jwtDecoder.decode(any())).thenReturn(jwt);
return jwtDecoder;
}
}

@Configuration
static class NoUniqueJwtDecoderFactoryConfig {

@Bean
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory1() {
return clientRegistration -> JwtDecoderFactoryConfig.getJwtDecoder();
}

@Bean
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory2() {
return clientRegistration -> JwtDecoderFactoryConfig.getJwtDecoder();
}

}

private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> createOauth2AccessTokenResponseClient() {
return request -> {
Map<String, Object> additionalParameters = new HashMap<>();
Expand Down
Loading