From 0bbf58f894bf3e04b9019827d00d7b3823c85217 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 16 Jun 2020 13:05:33 -0400 Subject: [PATCH 1/3] Register RestOperations @Bean for OAuth 2.0 Client Closes gh-5607 --- .../OAuth2ClientConfiguration.java | 50 +++++- .../oauth2/client/OAuth2ClientConfigurer.java | 6 +- .../client/OAuth2ClientConfigurerUtils.java | 9 +- .../oauth2/client/OAuth2LoginConfigurer.java | 25 ++- .../http/AuthenticationConfigBuilder.java | 36 +++-- .../OAuth2ClientBeanDefinitionParser.java | 2 + .../http/OAuth2LoginBeanDefinitionParser.java | 22 ++- ...2AuthorizedClientManagerPostProcessor.java | 151 ++++++++++++++++++ .../oauth2/client/OAuth2ClientBeanNames.java | 32 ++++ ...uth2ClientRestOperationsPostProcessor.java | 91 +++++++++++ .../OAuth2ClientConfigurationTests.java | 115 ++++++++++++- .../client/OAuth2ClientConfigurerTests.java | 96 ++++++++++- .../client/OAuth2LoginConfigurerTests.java | 92 +++++++++-- .../OAuth2ResourceServerConfigurerTests.java | 2 + ...OAuth2ClientBeanDefinitionParserTests.java | 34 ++++ .../OAuth2LoginBeanDefinitionParserTests.java | 49 +++++- ...nitionParserTests-CustomRestOperations.xml | 55 +++++++ ...onParserTests-WithCustomRestOperations.xml | 44 +++++ .../registration/ClientRegistrations.java | 29 ++-- .../registration/ClientRegistrationsTest.java | 6 + 20 files changed, 886 insertions(+), 60 deletions(-) create mode 100644 config/src/main/java/org/springframework/security/config/oauth2/client/DefaultOAuth2AuthorizedClientManagerPostProcessor.java create mode 100644 config/src/main/java/org/springframework/security/config/oauth2/client/OAuth2ClientBeanNames.java create mode 100644 config/src/main/java/org/springframework/security/config/oauth2/client/OAuth2ClientRestOperationsPostProcessor.java create mode 100644 config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomRestOperations.xml create mode 100644 config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-WithCustomRestOperations.xml diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index b29212d79ef..2a607ad5475 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,27 @@ package org.springframework.security.config.annotation.web.configuration; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; +import org.springframework.security.config.oauth2.client.DefaultOAuth2AuthorizedClientManagerPostProcessor; +import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; +import org.springframework.security.config.oauth2.client.OAuth2ClientRestOperationsPostProcessor; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.ClientRegistrations; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.util.ClassUtils; +import org.springframework.web.client.RestOperations; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; @@ -46,19 +54,35 @@ * @since 5.1 * @see OAuth2ImportSelector */ -@Import(OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class) +@Import({OAuth2ClientConfiguration.OAuth2ClientSecurityConfiguration.class, + OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class, + OAuth2ClientConfiguration.OAuth2ClientRegistrationsConfiguration.class}) final class OAuth2ClientConfiguration { + @Configuration(proxyBeanMethods = false) + static class OAuth2ClientSecurityConfiguration { + + @Bean + BeanDefinitionRegistryPostProcessor oauth2ClientRestOperationsPostProcessor() { + return new OAuth2ClientRestOperationsPostProcessor(); + } + + @Bean + BeanDefinitionRegistryPostProcessor defaultOAuth2AuthorizedClientManagerPostProcessor() { + return new DefaultOAuth2AuthorizedClientManagerPostProcessor(); + } + } + static class OAuth2ClientWebMvcImportSelector implements ImportSelector { @Override public String[] selectImports(AnnotationMetadata importingClassMetadata) { boolean webmvcPresent = ClassUtils.isPresent( - "org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader()); + "org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader()); return webmvcPresent ? - new String[] { "org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" } : - new String[] {}; + new String[] { "org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" } : + new String[] {}; } } @@ -91,23 +115,33 @@ public void addArgumentResolvers(List argumentRes } @Autowired(required = false) - public void setClientRegistrationRepository(List clientRegistrationRepositories) { + void setClientRegistrationRepository(List clientRegistrationRepositories) { if (clientRegistrationRepositories.size() == 1) { this.clientRegistrationRepository = clientRegistrationRepositories.get(0); } } @Autowired(required = false) - public void setAuthorizedClientRepository(List authorizedClientRepositories) { + void setAuthorizedClientRepository(List authorizedClientRepositories) { if (authorizedClientRepositories.size() == 1) { this.authorizedClientRepository = authorizedClientRepositories.get(0); } } @Autowired - public void setAccessTokenResponseClient( + void setAccessTokenResponseClient( Optional> accessTokenResponseClient) { accessTokenResponseClient.ifPresent(client -> this.accessTokenResponseClient = client); } } + + @Configuration(proxyBeanMethods = false) + static class OAuth2ClientRegistrationsConfiguration { + + @Autowired + @Qualifier(OAuth2ClientBeanNames.REST_OPERATIONS) + void configure(RestOperations restOperations) { + ClientRegistrations.setRestOperations(restOperations); + } + } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index af2f56e0cd5..d65f9190714 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -268,7 +268,11 @@ private OAuth2AccessTokenResponseClient get if (this.accessTokenResponseClient != null) { return this.accessTokenResponseClient; } - return new DefaultAuthorizationCodeTokenResponseClient(); + DefaultAuthorizationCodeTokenResponseClient authorizationCodeTokenResponseClient = + new DefaultAuthorizationCodeTokenResponseClient(); + authorizationCodeTokenResponseClient.setRestOperations( + OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder())); + return authorizationCodeTokenResponseClient; } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java index 046c6077399..c24ed9c65db 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,12 +20,14 @@ import org.springframework.context.ApplicationContext; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; +import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.StringUtils; +import org.springframework.web.client.RestOperations; import java.util.Map; @@ -96,4 +98,9 @@ private static > OAuth2AuthorizedClientService } return (!authorizedClientServiceMap.isEmpty() ? authorizedClientServiceMap.values().iterator().next() : null); } + + static > RestOperations getRestOperationsBean(B builder) { + return builder.getSharedObject(ApplicationContext.class).getBean( + OAuth2ClientBeanNames.REST_OPERATIONS, RestOperations.class); + } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 823de21892c..bd93f07175f 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -503,7 +503,11 @@ public void init(B http) throws Exception { OAuth2AccessTokenResponseClient accessTokenResponseClient = this.tokenEndpointConfig.accessTokenResponseClient; if (accessTokenResponseClient == null) { - accessTokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); + DefaultAuthorizationCodeTokenResponseClient authorizationCodeTokenResponseClient = + new DefaultAuthorizationCodeTokenResponseClient(); + authorizationCodeTokenResponseClient.setRestOperations( + OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder())); + accessTokenResponseClient = authorizationCodeTokenResponseClient; } OAuth2UserService oauth2UserService = getOAuth2UserService(); @@ -619,7 +623,11 @@ private OAuth2UserService getOidcUserService() { ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OidcUserRequest.class, OidcUser.class); OAuth2UserService bean = getBeanOrNull(type); if (bean == null) { - return new OidcUserService(); + DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + userService.setRestOperations(OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder())); + OidcUserService oidcUserService = new OidcUserService(); + oidcUserService.setOauth2UserService(userService); + return oidcUserService; } return bean; @@ -632,13 +640,18 @@ private OAuth2UserService getOAuth2UserService() ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OAuth2UserRequest.class, OAuth2User.class); OAuth2UserService bean = getBeanOrNull(type); if (bean == null) { + DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + userService.setRestOperations(OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder())); if (!this.userInfoEndpointConfig.customUserTypes.isEmpty()) { List> userServices = new ArrayList<>(); - userServices.add(new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes)); - userServices.add(new DefaultOAuth2UserService()); + CustomUserTypesOAuth2UserService customUserTypesUserService = + new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes); + customUserTypesUserService.setRestOperations(OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder())); + userServices.add(customUserTypesUserService); + userServices.add(userService); return new DelegatingOAuth2UserService<>(userServices); } else { - return new DefaultOAuth2UserService(); + return userService; } } diff --git a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java index 6de6a2f7115..6eae29bbfdc 100644 --- a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java @@ -15,18 +15,8 @@ */ package org.springframework.security.config.http; -import java.security.SecureRandom; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import javax.servlet.http.HttpServletRequest; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.w3c.dom.Element; - import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; @@ -41,6 +31,8 @@ import org.springframework.security.authentication.RememberMeAuthenticationProvider; import org.springframework.security.config.BeanIds; import org.springframework.security.config.Elements; +import org.springframework.security.config.oauth2.client.DefaultOAuth2AuthorizedClientManagerPostProcessor; +import org.springframework.security.config.oauth2.client.OAuth2ClientRestOperationsPostProcessor; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.mapping.SimpleAttributes2GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.SimpleMappableAttributesRetriever; @@ -65,6 +57,15 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; +import org.w3c.dom.Element; + +import javax.servlet.http.HttpServletRequest; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; import static org.springframework.security.config.http.SecurityFilters.ANONYMOUS_FILTER; import static org.springframework.security.config.http.SecurityFilters.BASIC_AUTH_FILTER; @@ -169,6 +170,7 @@ final class AuthenticationConfigBuilder { private BeanDefinition authorizationRequestRedirectFilter; private BeanDefinition authorizationCodeGrantFilter; private BeanReference authorizationCodeAuthenticationProviderRef; + private boolean oauth2ClientPostProcessorsRegistered; private final List authenticationProviders = new ManagedList<>(); private final Map defaultDeniedHandlerMappings = new ManagedMap<>(); @@ -312,6 +314,8 @@ void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authMa pc.registerBeanComponent(new BeanComponentDefinition( oauth2LoginOidcAuthProvider, oauth2LoginOidcAuthProviderId)); oauth2LoginOidcAuthenticationProviderRef = new RuntimeBeanReference(oauth2LoginOidcAuthProviderId); + + registerOAuth2ClientPostProcessors(); } void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenticationManager) { @@ -342,6 +346,18 @@ void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenti this.pc.registerBeanComponent(new BeanComponentDefinition( authorizationCodeAuthenticationProvider, authorizationCodeAuthenticationProviderId)); this.authorizationCodeAuthenticationProviderRef = new RuntimeBeanReference(authorizationCodeAuthenticationProviderId); + + registerOAuth2ClientPostProcessors(); + } + + private void registerOAuth2ClientPostProcessors() { + if (!this.oauth2ClientPostProcessorsRegistered) { + this.pc.getReaderContext().registerWithGeneratedName( + new RootBeanDefinition(OAuth2ClientRestOperationsPostProcessor.class)); + this.pc.getReaderContext().registerWithGeneratedName( + new RootBeanDefinition(DefaultOAuth2AuthorizedClientManagerPostProcessor.class)); + oauth2ClientPostProcessorsRegistered = true; + } } void createOpenIDLoginFilter(BeanReference sessionStrategy, BeanReference authManager) { diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java index 269143ede08..991dc65ed58 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java @@ -22,6 +22,7 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.ParserContext; +import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter; @@ -153,6 +154,7 @@ private BeanMetadataElement getAccessTokenResponseClient(Element element) { } else { accessTokenResponseClient = BeanDefinitionBuilder.rootBeanDefinition( "org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient") + .addPropertyReference("restOperations", OAuth2ClientBeanNames.REST_OPERATIONS) .getBeanDefinition(); } return accessTokenResponseClient; diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java index 0d2c7c44106..b7ca6c7de6a 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java @@ -15,13 +15,6 @@ */ package org.springframework.security.config.http; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanDefinition; @@ -37,6 +30,7 @@ import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.Elements; +import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationProvider; @@ -66,6 +60,13 @@ import org.springframework.web.accept.HeaderContentNegotiationStrategy; import org.w3c.dom.Element; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + /** * @author Ruby Hartono * @since 5.3 @@ -320,8 +321,13 @@ private BeanMetadataElement getOidcUserService(Element element) { if (!StringUtils.isEmpty(oidcUserServiceRef)) { oidcUserService = new RuntimeBeanReference(oidcUserServiceRef); } else { + BeanMetadataElement oauth2UserService = BeanDefinitionBuilder + .rootBeanDefinition("org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService") + .addPropertyReference("restOperations", OAuth2ClientBeanNames.REST_OPERATIONS) + .getBeanDefinition(); oidcUserService = BeanDefinitionBuilder .rootBeanDefinition("org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService") + .addPropertyValue("oauth2UserService", oauth2UserService) .getBeanDefinition(); } return oidcUserService; @@ -335,6 +341,7 @@ private BeanMetadataElement getOAuth2UserService(Element element) { } else { oauth2UserService = BeanDefinitionBuilder .rootBeanDefinition("org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService") + .addPropertyReference("restOperations", OAuth2ClientBeanNames.REST_OPERATIONS) .getBeanDefinition(); } return oauth2UserService; @@ -348,6 +355,7 @@ private BeanMetadataElement getAccessTokenResponseClient(Element element) { } else { accessTokenResponseClient = BeanDefinitionBuilder.rootBeanDefinition( "org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient") + .addPropertyReference("restOperations", OAuth2ClientBeanNames.REST_OPERATIONS) .getBeanDefinition(); } return accessTokenResponseClient; diff --git a/config/src/main/java/org/springframework/security/config/oauth2/client/DefaultOAuth2AuthorizedClientManagerPostProcessor.java b/config/src/main/java/org/springframework/security/config/oauth2/client/DefaultOAuth2AuthorizedClientManagerPostProcessor.java new file mode 100644 index 00000000000..2d9e331c991 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/oauth2/client/DefaultOAuth2AuthorizedClientManagerPostProcessor.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.oauth2.client; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.core.Ordered; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultPasswordTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.web.client.RestOperations; + +/** + * A {@link BeanDefinitionRegistryPostProcessor} that registers a {@link DefaultOAuth2AuthorizedClientManager} + * {@link BeanDefinition} with the name {@link OAuth2ClientBeanNames#DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER}. + * + * @author Joe Grandja + * @since 5.4 + */ +public final class DefaultOAuth2AuthorizedClientManagerPostProcessor implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware, Ordered { + private BeanFactory beanFactory; + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + if (registry.containsBeanDefinition(OAuth2ClientBeanNames.DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER)) { + // Return allowing for bean override + return; + } + + boolean clientRegistrationRepositoryAvailable = + BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory, + ClientRegistrationRepository.class, false, false).length == 1; + boolean authorizedClientRepositoryAvailable = + BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory, + OAuth2AuthorizedClientRepository.class, false, false).length == 1; + + if (clientRegistrationRepositoryAvailable && authorizedClientRepositoryAvailable) { + AbstractBeanDefinition beanDefinition = + BeanDefinitionBuilder.genericBeanDefinition(DefaultOAuth2AuthorizedClientManagerFactory.class) + .getBeanDefinition(); + registry.registerBeanDefinition(OAuth2ClientBeanNames.DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER, beanDefinition); + } + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + + private static class DefaultOAuth2AuthorizedClientManagerFactory implements FactoryBean, ApplicationContextAware { + private ApplicationContext applicationContext; + private OAuth2AuthorizedClientManager authorizedClientManager; + + @Override + public OAuth2AuthorizedClientManager getObject() throws Exception { + if (this.authorizedClientManager == null) { + this.authorizedClientManager = createDefaultAuthorizedClientManager(); + } + return this.authorizedClientManager; + } + + @Override + public Class getObjectType() { + return OAuth2AuthorizedClientManager.class; + } + + @Override + public boolean isSingleton() { + return true; + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.applicationContext = applicationContext; + } + + private OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager() { + ClientRegistrationRepository clientRegistrationRepository = + this.applicationContext.getBean(ClientRegistrationRepository.class); + OAuth2AuthorizedClientRepository authorizedClientRepository = + this.applicationContext.getBean(OAuth2AuthorizedClientRepository.class); + RestOperations restOperations = this.applicationContext.getBean( + OAuth2ClientBeanNames.REST_OPERATIONS, RestOperations.class); + + DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = + new DefaultRefreshTokenTokenResponseClient(); + refreshTokenTokenResponseClient.setRestOperations(restOperations); + + DefaultClientCredentialsTokenResponseClient clientCredentialsTokenResponseClient = + new DefaultClientCredentialsTokenResponseClient(); + clientCredentialsTokenResponseClient.setRestOperations(restOperations); + + DefaultPasswordTokenResponseClient passwordTokenResponseClient = + new DefaultPasswordTokenResponseClient(); + passwordTokenResponseClient.setRestOperations(restOperations); + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken(configurer -> configurer.accessTokenResponseClient(refreshTokenTokenResponseClient)) + .clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) + .password(configurer -> configurer.accessTokenResponseClient(passwordTokenResponseClient)) + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; + } + } +} diff --git a/config/src/main/java/org/springframework/security/config/oauth2/client/OAuth2ClientBeanNames.java b/config/src/main/java/org/springframework/security/config/oauth2/client/OAuth2ClientBeanNames.java new file mode 100644 index 00000000000..8d344aa71b7 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/oauth2/client/OAuth2ClientBeanNames.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.oauth2.client; + +import org.springframework.context.annotation.Bean; + +/** + * {@link Bean} names used (reserved) for OAuth 2.0 Client support. + * + * @author Joe Grandja + * @since 5.4 + */ +public interface OAuth2ClientBeanNames { + + String REST_OPERATIONS = "org.springframework.security.oauth2.client.restOperations"; + + String DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER = "org.springframework.security.oauth2.client.defaultOAuth2AuthorizedClientManager"; + +} diff --git a/config/src/main/java/org/springframework/security/config/oauth2/client/OAuth2ClientRestOperationsPostProcessor.java b/config/src/main/java/org/springframework/security/config/oauth2/client/OAuth2ClientRestOperationsPostProcessor.java new file mode 100644 index 00000000000..c7d6861b679 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/oauth2/client/OAuth2ClientRestOperationsPostProcessor.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.oauth2.client; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.core.Ordered; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; + +/** + * A {@link BeanDefinitionRegistryPostProcessor} that registers a {@link RestOperations} + * {@link BeanDefinition} with the name {@link OAuth2ClientBeanNames#REST_OPERATIONS}. + * + * @author Joe Grandja + * @since 5.4 + */ +public final class OAuth2ClientRestOperationsPostProcessor implements BeanDefinitionRegistryPostProcessor, Ordered { + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + if (registry.containsBeanDefinition(OAuth2ClientBeanNames.REST_OPERATIONS)) { + // Return allowing for bean override + return; + } + + AbstractBeanDefinition beanDefinition = + BeanDefinitionBuilder.genericBeanDefinition(OAuth2ClientRestOperationsFactory.class) + .getBeanDefinition(); + registry.registerBeanDefinition(OAuth2ClientBeanNames.REST_OPERATIONS, beanDefinition); + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + + private static class OAuth2ClientRestOperationsFactory implements FactoryBean { + private RestOperations restOperations; + + @Override + public RestOperations getObject() throws Exception { + if (this.restOperations == null) { + this.restOperations = createRestOperations(); + } + return this.restOperations; + } + + @Override + public Class getObjectType() { + return RestOperations.class; + } + + @Override + public boolean isSingleton() { + return true; + } + + private RestOperations createRestOperations() { + RestTemplate restTemplate = new RestTemplate(); + restTemplate.getMessageConverters().add(new OAuth2AccessTokenResponseHttpMessageConverter()); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + return restTemplate; + } + } +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java index 8c42b037d7f..17a669e2565 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,27 +21,42 @@ import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.RequestEntity; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.ClientRegistrations; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.test.web.servlet.MockMvc; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import javax.servlet.http.HttpServletRequest; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; @@ -91,7 +106,7 @@ public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws this.mockMvc.perform(get("/authorized-client").with(authentication(authentication))) .andExpect(status().isOk()) .andExpect(content().string("resolved")); - verifyZeroInteractions(accessTokenResponseClient); + verifyNoInteractions(accessTokenResponseClient); } @Test @@ -314,4 +329,98 @@ public OAuth2AccessTokenResponseClient acce return mock(OAuth2AccessTokenResponseClient.class); } } + + @Test + public void loadContextWhenPostProcessedThenRestOperationsRegistered() { + this.spring.register(OAuth2ClientBeansRegisteredConfig.class).autowire(); + + assertThat(this.spring.getContext().containsBean(OAuth2ClientBeanNames.REST_OPERATIONS)).isTrue(); + assertThat(this.spring.getContext().getBean(OAuth2ClientBeanNames.REST_OPERATIONS, + RestOperations.class)).isInstanceOf(RestTemplate.class); + } + + @Test + public void loadContextWhenPostProcessedThenDefaultOAuth2AuthorizedClientManagerRegistered() { + this.spring.register(OAuth2ClientBeansRegisteredConfig.class).autowire(); + + assertThat(this.spring.getContext().containsBean(OAuth2ClientBeanNames.DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER)).isTrue(); + assertThat(this.spring.getContext().getBean(OAuth2ClientBeanNames.DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER, + OAuth2AuthorizedClientManager.class)).isInstanceOf(DefaultOAuth2AuthorizedClientManager.class); + } + + @EnableWebSecurity + static class OAuth2ClientBeansRegisteredConfig extends WebSecurityConfigurerAdapter { + + @Bean + public ClientRegistrationRepository clientRegistrationRepository() { + return mock(ClientRegistrationRepository.class); + } + + @Bean + public OAuth2AuthorizedClientRepository authorizedClientRepository() { + return mock(OAuth2AuthorizedClientRepository.class); + } + } + + @Test + public void loadContextWhenPostProcessedAndBeansNotRegisteredThenDefaultOAuth2AuthorizedClientManagerNotRegistered() { + this.spring.register(OAuth2ClientBeansNotRegisteredConfig.class).autowire(); + + assertThat(this.spring.getContext().containsBean(OAuth2ClientBeanNames.DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER)).isFalse(); + } + + @EnableWebSecurity + static class OAuth2ClientBeansNotRegisteredConfig extends WebSecurityConfigurerAdapter { + } + + @Test + public void loadContextWhenOverrideBeansThenOverridden() { + this.spring.register(OAuth2ClientBeanOverridesConfig.class).autowire(); + + assertThat(this.spring.getContext().getBean(OAuth2ClientBeanNames.REST_OPERATIONS, + RestOperations.class)).isSameAs(OAuth2ClientBeanOverridesConfig.restOperations); + assertThat(this.spring.getContext().getBean(OAuth2ClientBeanNames.DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER, + OAuth2AuthorizedClientManager.class)).isSameAs(OAuth2ClientBeanOverridesConfig.authorizedClientManager); + } + + @EnableWebSecurity + static class OAuth2ClientBeanOverridesConfig extends WebSecurityConfigurerAdapter { + static RestOperations restOperations = mock(RestOperations.class); + static OAuth2AuthorizedClientManager authorizedClientManager = mock(OAuth2AuthorizedClientManager.class); + + @Bean(OAuth2ClientBeanNames.REST_OPERATIONS) + public RestOperations restOperations() { + return restOperations; + } + + @Bean(OAuth2ClientBeanNames.DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER) + public OAuth2AuthorizedClientManager authorizedClientManager() { + return authorizedClientManager; + } + + @Bean + public ClientRegistrationRepository clientRegistrationRepository() { + return mock(ClientRegistrationRepository.class); + } + + @Bean + public OAuth2AuthorizedClientRepository authorizedClientRepository() { + return mock(OAuth2AuthorizedClientRepository.class); + } + } + + @Test + public void loadContextWhenRestOperationsRegisteredThenClientRegistrationsUses() { + this.spring.register(OAuth2ClientBeanOverridesConfig.class).autowire(); + + when(OAuth2ClientBeanOverridesConfig.restOperations.exchange( + any(RequestEntity.class), any(ParameterizedTypeReference.class))) + .thenThrow(new IllegalStateException()); + + assertThatThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation("https://invalid.issuer.com")) + .isInstanceOf(IllegalStateException.class); + + verify(OAuth2ClientBeanOverridesConfig.restOperations).exchange( + any(RequestEntity.class), any(ParameterizedTypeReference.class)); + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java index ffc06ee6b02..6534bce4154 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java @@ -20,6 +20,8 @@ import org.junit.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; @@ -27,6 +29,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; @@ -55,6 +58,7 @@ import org.springframework.test.web.servlet.MvcResult; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.client.RestOperations; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import javax.servlet.http.HttpServletRequest; @@ -63,7 +67,11 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; @@ -263,6 +271,52 @@ public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationR verify(authorizationRequestResolver).resolve(any()); } + @Test + public void configureWhenRestOperationsProvidedAndClientAuthorizationSucceedsThenRestOperationsUsed() throws Exception { + this.spring.register(OAuth2ClientRestOperationsConfig.class).autowire(); + + // Setup the Authorization Request in the session + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId()); + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri()) + .clientId(this.registration1.getClientId()) + .redirectUri("http://localhost/client-1") + .state("state") + .attributes(attributes) + .build(); + + AuthorizationRequestRepository authorizationRequestRepository = + new HttpSessionOAuth2AuthorizationRequestRepository(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); + + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(300) + .build(); + when(OAuth2ClientRestOperationsConfig.restOperations.exchange( + any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .thenReturn(ResponseEntity.ok(accessTokenResponse)); + + MockHttpSession session = (MockHttpSession) request.getSession(); + + String principalName = "user1"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); + + this.mockMvc.perform(get("/client-1") + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .with(authentication(authentication)) + .session(session)) + .andExpect(status().is3xxRedirection()) + .andExpect(redirectedUrl("http://localhost/client-1")); + + verify(OAuth2ClientRestOperationsConfig.restOperations).exchange( + any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + @EnableWebSecurity @EnableWebMvc static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter { @@ -325,4 +379,44 @@ public OAuth2AuthorizedClientRepository authorizedClientRepository() { return authorizedClientRepository; } } + + @EnableWebSecurity + @EnableWebMvc + static class OAuth2ClientRestOperationsConfig extends WebSecurityConfigurerAdapter { + static RestOperations restOperations = mock(RestOperations.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2Client() + .authorizationCodeGrant() + .authorizationRequestResolver(authorizationRequestResolver); + } + + @Bean(OAuth2ClientBeanNames.REST_OPERATIONS) + public RestOperations restOperations() { + return restOperations; + } + + @Bean + public ClientRegistrationRepository clientRegistrationRepository() { + return clientRegistrationRepository; + } + + @Bean + public OAuth2AuthorizedClientRepository authorizedClientRepository() { + return authorizedClientRepository; + } + + @RestController + public class ResourceController { + @GetMapping("/resource1") + public String resource1(@RegisteredOAuth2AuthorizedClient("registration-1") OAuth2AuthorizedClient authorizedClient) { + return "resource1"; + } + } + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index 48274c5e08a..28a306659b1 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,26 +15,21 @@ */ package org.springframework.security.config.annotation.web.configurers.oauth2.client; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.apache.http.HttpHeaders; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; - import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -43,6 +38,7 @@ import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; +import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; @@ -85,12 +81,22 @@ import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.client.RestOperations; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; import static org.springframework.security.oauth2.jwt.TestJwts.jwt; @@ -585,6 +591,45 @@ public void logoutWhenUsingOidcLogoutHandlerThenRedirects() throws Exception { .andExpect(redirectedUrl("https://logout?id_token_hint=id-token")); } + @Test + public void oidcLoginWithCustomRestOperationsThenUsed() throws Exception { + // setup application context + loadConfig(OAuth2LoginConfigCustomRestOperations.class, JwtDecoderFactoryConfig.class); + + // setup authorization request + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); + this.authorizationRequestRepository.saveAuthorizationRequest( + authorizationRequest, this.request, this.response); + + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("accessToken123") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, "token123")) + .build(); + when(OAuth2LoginConfigCustomRestOperations.restOperations.exchange( + any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .thenReturn(ResponseEntity.ok(accessTokenResponse)); + + ParameterizedTypeReference> parameterizedType = + new ParameterizedTypeReference>() {}; + Map userInfoResponse = TestOidcUsers.create().getUserInfo().getClaims(); + when(OAuth2LoginConfigCustomRestOperations.restOperations.exchange( + any(RequestEntity.class), eq(parameterizedType))) + .thenReturn(ResponseEntity.ok(userInfoResponse)); + + // setup authentication parameters + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + + // perform test + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + + // assertions + verify(OAuth2LoginConfigCustomRestOperations.restOperations).exchange( + any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + verify(OAuth2LoginConfigCustomRestOperations.restOperations).exchange( + any(RequestEntity.class), eq(parameterizedType)); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -856,6 +901,33 @@ ClientRegistrationRepository clientRegistrationRepository() { } } + @EnableWebSecurity + static class OAuth2LoginConfigCustomRestOperations extends CommonWebSecurityConfigurerAdapter { + static RestOperations restOperations = mock(RestOperations.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .securityContext() + .securityContextRepository(securityContextRepository()) + .and() + .oauth2Login(); + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION); + } + + @Bean(OAuth2ClientBeanNames.REST_OPERATIONS) + RestOperations restOperations() { + return restOperations; + } + } + private static abstract class CommonWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java index cf552a6ece9..ed603e15406 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java @@ -59,6 +59,7 @@ import org.springframework.context.EnvironmentAware; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.convert.converter.Converter; import org.springframework.core.env.ConfigurableEnvironment; @@ -2397,6 +2398,7 @@ public Object getProperty(String name) { static class RestOperationsConfig { RestOperations rest = mock(RestOperations.class); + @Primary @Bean RestOperations rest() { return this.rest; diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java index 6774cfcdf2b..025119f1a0f 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java @@ -20,6 +20,8 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; @@ -41,12 +43,14 @@ import org.springframework.test.web.servlet.MvcResult; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestOperations; import java.util.HashMap; import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse; @@ -85,6 +89,9 @@ public class OAuth2ClientBeanDefinitionParserTests { @Autowired(required = false) private OAuth2AccessTokenResponseClient accessTokenResponseClient; + @Autowired(required = false) + private RestOperations restOperations; + @Autowired private MockMvc mvc; @@ -200,6 +207,33 @@ public void requestWhenCustomAuthorizedClientServiceThenCalled() throws Exceptio verify(this.authorizedClientService).saveAuthorizedClient(any(), any()); } + @WithMockUser + @Test + public void requestWhenCustomRestOperationsThenCalled() throws Exception { + this.spring.configLocations(xml("CustomRestOperations")).autowire(); + + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); + + OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest(clientRegistration); + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .thenReturn(authorizationRequest); + when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .thenReturn(authorizationRequest); + + OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); + when(this.restOperations.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .thenReturn(ResponseEntity.ok(accessTokenResponse)); + + MultiValueMap params = new LinkedMultiValueMap<>(); + params.add("code", "code123"); + params.add("state", authorizationRequest.getState()); + this.mvc.perform(get(authorizationRequest.getRedirectUri()).params(params)) + .andExpect(status().is3xxRedirection()) + .andExpect(redirectedUrl(authorizationRequest.getRedirectUri())); + + verify(this.restOperations).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + private static OAuth2AuthorizationRequest createAuthorizationRequest(ClientRegistration clientRegistration) { Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java index 9470cd03b46..f972582d4fe 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java @@ -20,7 +20,10 @@ import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import org.springframework.security.authentication.event.AuthenticationSuccessEvent; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.Authentication; @@ -57,9 +60,15 @@ import org.springframework.test.web.servlet.MvcResult; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestOperations; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -69,10 +78,6 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; - /** * Tests for {@link OAuth2LoginBeanDefinitionParser}. * @@ -108,6 +113,9 @@ public class OAuth2LoginBeanDefinitionParserTests { @Autowired(required = false) private OAuth2UserService oauth2UserService; + @Autowired(required = false) + private RestOperations restOperations; + @Autowired(required = false) private JwtDecoderFactory jwtDecoderFactory; @@ -489,6 +497,39 @@ public void requestWhenCustomAuthorizedClientServiceThenCalled() throws Exceptio verify(authorizedClientService).saveAuthorizedClient(any(), any()); } + @Test + public void requestWhenCustomRestOperationsThenCalled() throws Exception { + this.spring.configLocations(this.xml("WithCustomRestOperations")).autowire(); + + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().userNameAttributeName("username").build(); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(clientRegistration); + + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .attributes(attributes).build(); + when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); + + OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); + when(this.restOperations.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .thenReturn(ResponseEntity.ok(accessTokenResponse)); + + ParameterizedTypeReference> parameterizedType = + new ParameterizedTypeReference>() {}; + Map userInfoResponse = TestOAuth2Users.create().getAttributes(); + when(this.restOperations.exchange( + any(RequestEntity.class), eq(parameterizedType))) + .thenReturn(ResponseEntity.ok(userInfoResponse)); + + MultiValueMap params = new LinkedMultiValueMap<>(); + params.add("code", "code123"); + params.add("state", authorizationRequest.getState()); + this.mvc.perform(get("/login/oauth2/code/" + clientRegistration.getRegistrationId()).params(params)); + + verify(this.restOperations).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + verify(this.restOperations).exchange(any(RequestEntity.class), eq(parameterizedType)); + } + private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomRestOperations.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomRestOperations.xml new file mode 100644 index 00000000000..98c9f8438b5 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomRestOperations.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-WithCustomRestOperations.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-WithCustomRestOperations.xml new file mode 100644 index 00000000000..577d3793058 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-WithCustomRestOperations.xml @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java index 857b150db09..81d601649c0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java @@ -16,20 +16,12 @@ package org.springframework.security.oauth2.client.registration; -import java.net.URI; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Supplier; - import com.nimbusds.oauth2.sdk.GrantType; import com.nimbusds.oauth2.sdk.ParseException; import com.nimbusds.oauth2.sdk.Scope; import com.nimbusds.oauth2.sdk.as.AuthorizationServerMetadata; import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; import net.minidev.json.JSONObject; - import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -38,9 +30,17 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.util.Assert; import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; +import java.net.URI; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + /** * Allows creating a {@link ClientRegistration.Builder} from an * OpenID Provider Configuration @@ -55,7 +55,7 @@ public final class ClientRegistrations { private static final String OIDC_METADATA_PATH = "/.well-known/openid-configuration"; private static final String OAUTH_METADATA_PATH = "/.well-known/oauth-authorization-server"; - private static final RestTemplate rest = new RestTemplate(); + private static RestOperations rest = new RestTemplate(); private static final ParameterizedTypeReference> typeReference = new ParameterizedTypeReference>() {}; @@ -138,6 +138,17 @@ public static ClientRegistration.Builder fromIssuerLocation(String issuer) { return getBuilder(issuer, oidc(uri), oidcRfc8414(uri), oauth(uri)); } + /** + * Sets the {@link RestOperations} used when requesting the discovery endpoint. + * + * @since 5.4 + * @param restOperations the {@link RestOperations} used when requesting the discovery endpoint + */ + public static void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + rest = restOperations; + } + private static Supplier oidc(URI issuer) { URI uri = UriComponentsBuilder.fromUri(issuer) .replacePath(issuer.getPath() + OIDC_METADATA_PATH).build(Collections.emptyMap()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java index 03677717b18..8bb1a614876 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java @@ -122,6 +122,12 @@ public void cleanup() throws Exception { this.server.shutdown(); } + @Test + public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> ClientRegistrations.setRestOperations(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void issuerWhenAllInformationThenSuccess() throws Exception { ClientRegistration registration = registration("").build(); From 19f57f99f75d29b7c8adb75ca380c62e6825cd0d Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 24 Jun 2020 09:02:00 -0400 Subject: [PATCH 2/3] Remove ClientRegistrations.setRestOperations() --- .../OAuth2ClientConfiguration.java | 17 +---------- .../OAuth2ClientConfigurationTests.java | 18 ------------ .../registration/ClientRegistrations.java | 29 ++++++------------- .../registration/ClientRegistrationsTest.java | 6 ---- 4 files changed, 10 insertions(+), 60 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 2a607ad5475..95e5b9c94fa 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -16,7 +16,6 @@ package org.springframework.security.config.annotation.web.configuration; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -24,19 +23,16 @@ import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; import org.springframework.security.config.oauth2.client.DefaultOAuth2AuthorizedClientManagerPostProcessor; -import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; import org.springframework.security.config.oauth2.client.OAuth2ClientRestOperationsPostProcessor; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.registration.ClientRegistrations; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.util.ClassUtils; -import org.springframework.web.client.RestOperations; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; @@ -55,8 +51,7 @@ * @see OAuth2ImportSelector */ @Import({OAuth2ClientConfiguration.OAuth2ClientSecurityConfiguration.class, - OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class, - OAuth2ClientConfiguration.OAuth2ClientRegistrationsConfiguration.class}) + OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class}) final class OAuth2ClientConfiguration { @Configuration(proxyBeanMethods = false) @@ -134,14 +129,4 @@ void setAccessTokenResponseClient( accessTokenResponseClient.ifPresent(client -> this.accessTokenResponseClient = client); } } - - @Configuration(proxyBeanMethods = false) - static class OAuth2ClientRegistrationsConfiguration { - - @Autowired - @Qualifier(OAuth2ClientBeanNames.REST_OPERATIONS) - void configure(RestOperations restOperations) { - ClientRegistrations.setRestOperations(restOperations); - } - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java index 17a669e2565..de65e9ef7c6 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java @@ -21,8 +21,6 @@ import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.RequestEntity; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; @@ -34,7 +32,6 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.registration.ClientRegistrations; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -408,19 +405,4 @@ public OAuth2AuthorizedClientRepository authorizedClientRepository() { return mock(OAuth2AuthorizedClientRepository.class); } } - - @Test - public void loadContextWhenRestOperationsRegisteredThenClientRegistrationsUses() { - this.spring.register(OAuth2ClientBeanOverridesConfig.class).autowire(); - - when(OAuth2ClientBeanOverridesConfig.restOperations.exchange( - any(RequestEntity.class), any(ParameterizedTypeReference.class))) - .thenThrow(new IllegalStateException()); - - assertThatThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation("https://invalid.issuer.com")) - .isInstanceOf(IllegalStateException.class); - - verify(OAuth2ClientBeanOverridesConfig.restOperations).exchange( - any(RequestEntity.class), any(ParameterizedTypeReference.class)); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java index 81d601649c0..857b150db09 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java @@ -16,12 +16,20 @@ package org.springframework.security.oauth2.client.registration; +import java.net.URI; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + import com.nimbusds.oauth2.sdk.GrantType; import com.nimbusds.oauth2.sdk.ParseException; import com.nimbusds.oauth2.sdk.Scope; import com.nimbusds.oauth2.sdk.as.AuthorizationServerMetadata; import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; import net.minidev.json.JSONObject; + import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -30,17 +38,9 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.util.Assert; import org.springframework.web.client.HttpClientErrorException; -import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; -import java.net.URI; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Supplier; - /** * Allows creating a {@link ClientRegistration.Builder} from an * OpenID Provider Configuration @@ -55,7 +55,7 @@ public final class ClientRegistrations { private static final String OIDC_METADATA_PATH = "/.well-known/openid-configuration"; private static final String OAUTH_METADATA_PATH = "/.well-known/oauth-authorization-server"; - private static RestOperations rest = new RestTemplate(); + private static final RestTemplate rest = new RestTemplate(); private static final ParameterizedTypeReference> typeReference = new ParameterizedTypeReference>() {}; @@ -138,17 +138,6 @@ public static ClientRegistration.Builder fromIssuerLocation(String issuer) { return getBuilder(issuer, oidc(uri), oidcRfc8414(uri), oauth(uri)); } - /** - * Sets the {@link RestOperations} used when requesting the discovery endpoint. - * - * @since 5.4 - * @param restOperations the {@link RestOperations} used when requesting the discovery endpoint - */ - public static void setRestOperations(RestOperations restOperations) { - Assert.notNull(restOperations, "restOperations cannot be null"); - rest = restOperations; - } - private static Supplier oidc(URI issuer) { URI uri = UriComponentsBuilder.fromUri(issuer) .replacePath(issuer.getPath() + OIDC_METADATA_PATH).build(Collections.emptyMap()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java index 8bb1a614876..03677717b18 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java @@ -122,12 +122,6 @@ public void cleanup() throws Exception { this.server.shutdown(); } - @Test - public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> ClientRegistrations.setRestOperations(null)) - .isInstanceOf(IllegalArgumentException.class); - } - @Test public void issuerWhenAllInformationThenSuccess() throws Exception { ClientRegistration registration = registration("").build(); From 0a724f320a4af1e52aa647c6e4e44212a2e3544c Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 24 Jun 2020 09:14:49 -0400 Subject: [PATCH 3/3] Update sample --- .../java/sample/config/WebClientConfig.java | 26 +++---------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java index da9510602bd..8b54c74d483 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java @@ -16,15 +16,12 @@ package sample.config; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction; import org.springframework.web.reactive.function.client.WebClient; @@ -38,7 +35,7 @@ public class WebClientConfig { @Value("${resource-uri}") String resourceUri; @Bean - WebClient webClient(OAuth2AuthorizedClientManager authorizedClientManager) { + WebClient webClient(@Qualifier(OAuth2ClientBeanNames.DEFAULT_OAUTH2_AUTHORIZED_CLIENT_MANAGER) OAuth2AuthorizedClientManager authorizedClientManager) { ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); oauth2.setDefaultOAuth2AuthorizedClient(true); @@ -47,21 +44,4 @@ WebClient webClient(OAuth2AuthorizedClientManager authorizedClientManager) { .apply(oauth2.oauth2Configuration()) .build(); } - - @Bean - OAuth2AuthorizedClientManager authorizedClientManager(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository) { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .password() - .build(); - DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - clientRegistrationRepository, authorizedClientRepository); - authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - - return authorizedClientManager; - } }