Skip to content

Register RestOperations @Bean for OAuth 2.0 Client #8732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,10 +16,14 @@
package org.springframework.security.config.annotation.web.configuration;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.security.config.oauth2.client.DefaultOAuth2AuthorizedClientManagerPostProcessor;
import org.springframework.security.config.oauth2.client.OAuth2ClientRestOperationsPostProcessor;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
Expand All @@ -46,19 +50,34 @@
* @since 5.1
* @see OAuth2ImportSelector
*/
@Import(OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class)
@Import({OAuth2ClientConfiguration.OAuth2ClientSecurityConfiguration.class,
OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class})
final class OAuth2ClientConfiguration {

@Configuration(proxyBeanMethods = false)
static class OAuth2ClientSecurityConfiguration {

@Bean
BeanDefinitionRegistryPostProcessor oauth2ClientRestOperationsPostProcessor() {
return new OAuth2ClientRestOperationsPostProcessor();
}

@Bean
BeanDefinitionRegistryPostProcessor defaultOAuth2AuthorizedClientManagerPostProcessor() {
return new DefaultOAuth2AuthorizedClientManagerPostProcessor();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having trouble seeing the value of registering OAuth2AuthorizedClientManager with a known bean name.

Doing so with RestOperations makes sense since there are many use cases when an application publishes a RestOperations but doesn't intend it to be used by Spring Security.

What is it about OAuth2AuthorizedClientManager that requires the same treatment as RestOperations? Is this a needed pattern for other OAuth 2.0 Client beans?

Copy link
Contributor Author

@jgrandja jgrandja Jun 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OAuth2AuthorizedClientManager is the central interface for client since 5.2. It is used primarily by ServletOAuth2AuthorizedClientExchangeFilterFunction and OAuth2AuthorizedClientArgumentResolver. It indirectly uses RestOperations. The relationship chain is as follows:

OAuth2AuthorizedClientManager has 1 or more OAuth2AuthorizedClientProvider
OAuth2AuthorizedClientProvider has an OAuth2AccessTokenResponseClient
OAuth2AccessTokenResponseClient has a RestOperations

If an application configures a custom RestOperations then it will automatically be configured with the "default" OAuth2AuthorizedClientManager and associated OAuth2AuthorizedClientProvider's, and the user could leverage this default @Bean instead of explicitly configuring it. It really is meant to be a convenience mechanism.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One important thing to note is that an application may configure more than 1 OAuth2AuthorizedClientManager in their setup. For example, they may have a background service that performs client_credentials grant and another OAuth2AuthorizedClientManager that handles authorization_code and refresh_token in the web context.

The "default" OAuth2AuthorizedClientManager will be wired into the OAuth2AuthorizedClientArgumentResolver and may be used by ServletOAuth2AuthorizedClientExchangeFilterFunction.

}
}

static class OAuth2ClientWebMvcImportSelector implements ImportSelector {

@Override
public String[] selectImports(AnnotationMetadata importingClassMetadata) {
boolean webmvcPresent = ClassUtils.isPresent(
"org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader());
"org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader());

return webmvcPresent ?
new String[] { "org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" } :
new String[] {};
new String[] { "org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" } :
new String[] {};
}
}

Expand Down Expand Up @@ -91,21 +110,21 @@ public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentRes
}

@Autowired(required = false)
public void setClientRegistrationRepository(List<ClientRegistrationRepository> clientRegistrationRepositories) {
void setClientRegistrationRepository(List<ClientRegistrationRepository> clientRegistrationRepositories) {
if (clientRegistrationRepositories.size() == 1) {
this.clientRegistrationRepository = clientRegistrationRepositories.get(0);
}
}

@Autowired(required = false)
public void setAuthorizedClientRepository(List<OAuth2AuthorizedClientRepository> authorizedClientRepositories) {
void setAuthorizedClientRepository(List<OAuth2AuthorizedClientRepository> authorizedClientRepositories) {
if (authorizedClientRepositories.size() == 1) {
this.authorizedClientRepository = authorizedClientRepositories.get(0);
}
}

@Autowired
public void setAccessTokenResponseClient(
void setAccessTokenResponseClient(
Optional<OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest>> accessTokenResponseClient) {
accessTokenResponseClient.ifPresent(client -> this.accessTokenResponseClient = client);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> get
if (this.accessTokenResponseClient != null) {
return this.accessTokenResponseClient;
}
return new DefaultAuthorizationCodeTokenResponseClient();
DefaultAuthorizationCodeTokenResponseClient authorizationCodeTokenResponseClient =
new DefaultAuthorizationCodeTokenResponseClient();
authorizationCodeTokenResponseClient.setRestOperations(
OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder()));
return authorizationCodeTokenResponseClient;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,12 +20,14 @@
import org.springframework.context.ApplicationContext;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations;

import java.util.Map;

Expand Down Expand Up @@ -96,4 +98,9 @@ private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientService
}
return (!authorizedClientServiceMap.isEmpty() ? authorizedClientServiceMap.values().iterator().next() : null);
}

static <B extends HttpSecurityBuilder<B>> RestOperations getRestOperationsBean(B builder) {
return builder.getSharedObject(ApplicationContext.class).getBean(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it would fail if there isn't a bean by that name and type or if there is more than one bean of this type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you think we should use a bean name here? We don't typically do the lookups by bean name.

Copy link
Contributor Author

@jgrandja jgrandja Jun 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it would fail if there isn't a bean by that name and type or if there is more than one bean of this type

OAuth2ClientRestOperationsPostProcessor defined in OAuth2ClientConfiguration will register the @Bean using OAuth2ClientBeanNames.REST_OPERATIONS so it won't fail. It's expected to be registered.

Is there a reason you think we should use a bean name here?

I thought this is what we decided on, as per comment

Furthermore, an application may register one or more RestOperations @Bean so we need to ensure we pick up the correct RestOperations @Bean to be used for the client flows - hence the use of bean name.

Also, see tests in OAuth2ClientConfigurationTests starting at loadContextWhenPostProcessedThenRestOperationsRegistered() which demonstrates that the @Bean is registered by default and how a user could override the @Bean

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than relying on OAuth2ClientRestOperationsPostProcessor could we just default the RestOperations if none was found? This is what we do in other circumstances.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what we do in other circumstances

Can you please be specific and provide a code sample I could reference?

Copy link
Contributor Author

@jgrandja jgrandja Jun 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rwinch Are you referring to the way we conditionally @Import using ImportSelector? For example, OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We often try to get a Bean if it doesn't exist we create a default instance. An example is PasswordEncoder

OAuth2ClientBeanNames.REST_OPERATIONS, RestOperations.class);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -503,7 +503,11 @@ public void init(B http) throws Exception {
OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient =
this.tokenEndpointConfig.accessTokenResponseClient;
if (accessTokenResponseClient == null) {
accessTokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient();
DefaultAuthorizationCodeTokenResponseClient authorizationCodeTokenResponseClient =
new DefaultAuthorizationCodeTokenResponseClient();
authorizationCodeTokenResponseClient.setRestOperations(
OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder()));
accessTokenResponseClient = authorizationCodeTokenResponseClient;
}

OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService = getOAuth2UserService();
Expand Down Expand Up @@ -619,7 +623,11 @@ private OAuth2UserService<OidcUserRequest, OidcUser> getOidcUserService() {
ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OidcUserRequest.class, OidcUser.class);
OAuth2UserService<OidcUserRequest, OidcUser> bean = getBeanOrNull(type);
if (bean == null) {
return new OidcUserService();
DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
userService.setRestOperations(OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder()));
OidcUserService oidcUserService = new OidcUserService();
oidcUserService.setOauth2UserService(userService);
return oidcUserService;
}

return bean;
Expand All @@ -632,13 +640,18 @@ private OAuth2UserService<OAuth2UserRequest, OAuth2User> getOAuth2UserService()
ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OAuth2UserRequest.class, OAuth2User.class);
OAuth2UserService<OAuth2UserRequest, OAuth2User> bean = getBeanOrNull(type);
if (bean == null) {
DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
userService.setRestOperations(OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder()));
if (!this.userInfoEndpointConfig.customUserTypes.isEmpty()) {
List<OAuth2UserService<OAuth2UserRequest, OAuth2User>> userServices = new ArrayList<>();
userServices.add(new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes));
userServices.add(new DefaultOAuth2UserService());
CustomUserTypesOAuth2UserService customUserTypesUserService =
new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes);
customUserTypesUserService.setRestOperations(OAuth2ClientConfigurerUtils.getRestOperationsBean(getBuilder()));
userServices.add(customUserTypesUserService);
userServices.add(userService);
return new DelegatingOAuth2UserService<>(userServices);
} else {
return new DefaultOAuth2UserService();
return userService;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,8 @@
*/
package org.springframework.security.config.http;

import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.w3c.dom.Element;

import org.springframework.beans.BeanMetadataElement;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanReference;
Expand All @@ -41,6 +31,8 @@
import org.springframework.security.authentication.RememberMeAuthenticationProvider;
import org.springframework.security.config.BeanIds;
import org.springframework.security.config.Elements;
import org.springframework.security.config.oauth2.client.DefaultOAuth2AuthorizedClientManagerPostProcessor;
import org.springframework.security.config.oauth2.client.OAuth2ClientRestOperationsPostProcessor;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.mapping.SimpleAttributes2GrantedAuthoritiesMapper;
import org.springframework.security.core.authority.mapping.SimpleMappableAttributesRetriever;
Expand All @@ -65,6 +57,15 @@
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
import org.w3c.dom.Element;

import javax.servlet.http.HttpServletRequest;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static org.springframework.security.config.http.SecurityFilters.ANONYMOUS_FILTER;
import static org.springframework.security.config.http.SecurityFilters.BASIC_AUTH_FILTER;
Expand Down Expand Up @@ -169,6 +170,7 @@ final class AuthenticationConfigBuilder {
private BeanDefinition authorizationRequestRedirectFilter;
private BeanDefinition authorizationCodeGrantFilter;
private BeanReference authorizationCodeAuthenticationProviderRef;
private boolean oauth2ClientPostProcessorsRegistered;

private final List<BeanReference> authenticationProviders = new ManagedList<>();
private final Map<BeanDefinition, BeanMetadataElement> defaultDeniedHandlerMappings = new ManagedMap<>();
Expand Down Expand Up @@ -312,6 +314,8 @@ void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authMa
pc.registerBeanComponent(new BeanComponentDefinition(
oauth2LoginOidcAuthProvider, oauth2LoginOidcAuthProviderId));
oauth2LoginOidcAuthenticationProviderRef = new RuntimeBeanReference(oauth2LoginOidcAuthProviderId);

registerOAuth2ClientPostProcessors();
}

void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenticationManager) {
Expand Down Expand Up @@ -342,6 +346,18 @@ void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenti
this.pc.registerBeanComponent(new BeanComponentDefinition(
authorizationCodeAuthenticationProvider, authorizationCodeAuthenticationProviderId));
this.authorizationCodeAuthenticationProviderRef = new RuntimeBeanReference(authorizationCodeAuthenticationProviderId);

registerOAuth2ClientPostProcessors();
}

private void registerOAuth2ClientPostProcessors() {
if (!this.oauth2ClientPostProcessorsRegistered) {
this.pc.getReaderContext().registerWithGeneratedName(
new RootBeanDefinition(OAuth2ClientRestOperationsPostProcessor.class));
this.pc.getReaderContext().registerWithGeneratedName(
new RootBeanDefinition(DefaultOAuth2AuthorizedClientManagerPostProcessor.class));
oauth2ClientPostProcessorsRegistered = true;
}
}

void createOpenIDLoginFilter(BeanReference sessionStrategy, BeanReference authManager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.xml.BeanDefinitionParser;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
Expand Down Expand Up @@ -153,6 +154,7 @@ private BeanMetadataElement getAccessTokenResponseClient(Element element) {
} else {
accessTokenResponseClient = BeanDefinitionBuilder.rootBeanDefinition(
"org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient")
.addPropertyReference("restOperations", OAuth2ClientBeanNames.REST_OPERATIONS)
.getBeanDefinition();
}
return accessTokenResponseClient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@
*/
package org.springframework.security.config.http;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.springframework.beans.BeanMetadataElement;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanDefinition;
Expand All @@ -37,6 +30,7 @@
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.Elements;
import org.springframework.security.config.oauth2.client.OAuth2ClientBeanNames;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationProvider;
Expand Down Expand Up @@ -66,6 +60,13 @@
import org.springframework.web.accept.HeaderContentNegotiationStrategy;
import org.w3c.dom.Element;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
* @author Ruby Hartono
* @since 5.3
Expand Down Expand Up @@ -320,8 +321,13 @@ private BeanMetadataElement getOidcUserService(Element element) {
if (!StringUtils.isEmpty(oidcUserServiceRef)) {
oidcUserService = new RuntimeBeanReference(oidcUserServiceRef);
} else {
BeanMetadataElement oauth2UserService = BeanDefinitionBuilder
.rootBeanDefinition("org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService")
.addPropertyReference("restOperations", OAuth2ClientBeanNames.REST_OPERATIONS)
.getBeanDefinition();
oidcUserService = BeanDefinitionBuilder
.rootBeanDefinition("org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService")
.addPropertyValue("oauth2UserService", oauth2UserService)
.getBeanDefinition();
}
return oidcUserService;
Expand All @@ -335,6 +341,7 @@ private BeanMetadataElement getOAuth2UserService(Element element) {
} else {
oauth2UserService = BeanDefinitionBuilder
.rootBeanDefinition("org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService")
.addPropertyReference("restOperations", OAuth2ClientBeanNames.REST_OPERATIONS)
.getBeanDefinition();
}
return oauth2UserService;
Expand All @@ -348,6 +355,7 @@ private BeanMetadataElement getAccessTokenResponseClient(Element element) {
} else {
accessTokenResponseClient = BeanDefinitionBuilder.rootBeanDefinition(
"org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient")
.addPropertyReference("restOperations", OAuth2ClientBeanNames.REST_OPERATIONS)
.getBeanDefinition();
}
return accessTokenResponseClient;
Expand Down
Loading