diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 4151010a846..532d9078b55 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -15,22 +15,26 @@ */ package org.springframework.security.config.annotation.web.configuration; -import java.util.List; -import java.util.Optional; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.util.ClassUtils; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; +import java.util.List; +import java.util.Optional; + /** * {@link Configuration} for OAuth 2.0 Client support. * @@ -67,13 +71,17 @@ static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer @Override public void addArgumentResolvers(List argumentResolvers) { if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) { - OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver = - new OAuth2AuthorizedClientArgumentResolver( - this.clientRegistrationRepository, this.authorizedClientRepository); - if (this.accessTokenResponseClient != null) { - authorizedClientArgumentResolver.setClientCredentialsTokenResponseClient(this.accessTokenResponseClient); - } - argumentResolvers.add(authorizedClientArgumentResolver); + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials(configurer -> + Optional.ofNullable(this.accessTokenResponseClient).ifPresent(configurer::accessTokenResponseClient)) + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + argumentResolvers.add(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager)); } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java index b01ece3c430..d96df5495aa 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java @@ -15,21 +15,6 @@ */ package org.springframework.security.config.annotation.web.configuration; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; - -import javax.servlet.http.HttpServletRequest; import org.junit.Rule; import org.junit.Test; import org.springframework.beans.factory.NoSuchBeanDefinitionException; @@ -53,6 +38,19 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import javax.servlet.http.HttpServletRequest; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + /** * Tests for {@link OAuth2ClientConfiguration}. * @@ -72,8 +70,12 @@ public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + ClientRegistration clientRegistration = clientRegistration().registrationId(clientRegistrationId).build(); + when(clientRegistrationRepository.findByRegistrationId(eq(clientRegistrationId))).thenReturn(clientRegistration); + OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); + when(authorizedClient.getClientRegistration()).thenReturn(clientRegistration); when(authorizedClientRepository.loadAuthorizedClient( eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class))) .thenReturn(authorizedClient); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..7ff23c3ceb5 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.util.Assert; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + */ +public final class AuthorizationCodeOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + + /** + * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Returns {@code null} if authorization is not supported, + * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} + * is not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the client is already authorized. + * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported + */ + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) && + context.getAuthorizedClient() == null) { + // ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectFilter which initiates authorization + throw new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId()); + } + return null; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..36dcb6a4848 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +import java.time.Duration; +import java.time.Instant; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + * @see DefaultClientCredentialsTokenResponseClient + */ +public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private OAuth2AccessTokenResponseClient accessTokenResponseClient = + new DefaultClientCredentialsTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + + /** + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Returns {@code null} if authorization (or re-authorization) is not supported, + * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} + * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} OR + * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or re-authorization) is not supported + */ + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + + ClientRegistration clientRegistration = context.getClientRegistration(); + if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return null; + } + + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + // If client is already authorized but access token is NOT expired than no need for re-authorization + return null; + } + + // As per spec, in section 4.4.3 Access Token Response + // https://tools.ietf.org/html/rfc6749#section-4.4.3 + // A refresh token SHOULD NOT be included. + // + // Therefore, renewing an expired access token (re-authorization) + // is the same as acquiring a new access token (authorization). + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(clientRegistration); + OAuth2AccessTokenResponse tokenResponse = + this.accessTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken()); + } + + private boolean hasTokenExpired(AbstractOAuth2Token token) { + return token.getExpiresAt().isBefore(Instant.now().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant + */ + public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..0343b96071c --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} that simply delegates + * to it's internal {@code List} of {@link OAuth2AuthorizedClientProvider}(s). + *

+ * Each provider is given a chance to + * {@link OAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) authorize} + * the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context + * with the first {@code non-null} {@link OAuth2AuthorizedClient} being returned. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + */ +public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private final List authorizedClientProviders; + + /** + * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param authorizedClientProviders a list of {@link OAuth2AuthorizedClientProvider}(s) + */ + public DelegatingOAuth2AuthorizedClientProvider(OAuth2AuthorizedClientProvider... authorizedClientProviders) { + Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); + this.authorizedClientProviders = Collections.unmodifiableList(Arrays.asList(authorizedClientProviders)); + } + + /** + * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param authorizedClientProviders a {@code List} of {@link OAuth2AuthorizedClientProvider}(s) + */ + public DelegatingOAuth2AuthorizedClientProvider(List authorizedClientProviders) { + Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); + this.authorizedClientProviders = Collections.unmodifiableList(new ArrayList<>(authorizedClientProviders)); + } + + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + return this.authorizedClientProviders.stream() + .map(authorizedClientProvider -> authorizedClientProvider.authorize(context)) + .filter(Objects::nonNull) + .findFirst() + .orElse(null); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java new file mode 100644 index 00000000000..d7aa1aec171 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java @@ -0,0 +1,203 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A context that holds authorization-specific state and is used by an {@link OAuth2AuthorizedClientProvider} + * when attempting to authorize (or re-authorize) an OAuth 2.0 Client. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + */ +public final class OAuth2AuthorizationContext { + /** + * The name of the {@link #getAttribute(String) attribute} + * in the {@link OAuth2AuthorizationContext context} + * associated to the value for the "request scope(s)". + * The value of the attribute is a {@code String[]} of scope(s) to be requested + * by the {@link OAuth2AuthorizationContext#getClientRegistration() client}. + */ + public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = OAuth2AuthorizationContext.class.getName().concat(".REQUEST_SCOPE"); + + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + private Map attributes; + + private OAuth2AuthorizationContext() { + } + + /** + * Returns the {@link ClientRegistration client registration}. + * + * @return the {@link ClientRegistration} + */ + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + + /** + * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} + * if the {@link #withClientRegistration(ClientRegistration) client registration} was supplied. + * + * @return the {@link OAuth2AuthorizedClient} or {@code null} if the client registration was supplied + */ + @Nullable + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + /** + * Returns the {@code Principal} (to be) associated to the authorized client. + * + * @return the {@code Principal} (to be) associated to the authorized client + */ + public Authentication getPrincipal() { + return this.principal; + } + + /** + * Returns the attributes associated to the context. + * + * @return a {@code Map} of the attributes associated to the context + */ + public Map getAttributes() { + return this.attributes; + } + + /** + * Returns the value of an attribute associated to the context or {@code null} if not available. + * + * @param name the name of the attribute + * @param the type of the attribute + * @return the value of the attribute associated to the context + */ + @Nullable + @SuppressWarnings("unchecked") + public T getAttribute(String name) { + return (T) this.getAttributes().get(name); + } + + /** + * Returns a new {@link Builder} initialized with the {@link ClientRegistration}. + * + * @param clientRegistration the {@link ClientRegistration client registration} + * @return the {@link Builder} + */ + public static Builder withClientRegistration(ClientRegistration clientRegistration) { + return new Builder(clientRegistration); + } + + /** + * Returns a new {@link Builder} initialized with the {@link OAuth2AuthorizedClient}. + * + * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} + * @return the {@link Builder} + */ + public static Builder withAuthorizedClient(OAuth2AuthorizedClient authorizedClient) { + return new Builder(authorizedClient); + } + + /** + * A builder for {@link OAuth2AuthorizationContext}. + */ + public static class Builder { + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + private Map attributes; + + private Builder(ClientRegistration clientRegistration) { + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + this.clientRegistration = clientRegistration; + } + + private Builder(OAuth2AuthorizedClient authorizedClient) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + this.authorizedClient = authorizedClient; + } + + /** + * Sets the {@code Principal} (to be) associated to the authorized client. + * + * @param principal the {@code Principal} (to be) associated to the authorized client + * @return the {@link Builder} + */ + public Builder principal(Authentication principal) { + this.principal = principal; + return this; + } + + /** + * Sets the attributes associated to the context. + * + * @param attributes the attributes associated to the context + * @return the {@link Builder} + */ + public Builder attributes(Map attributes) { + this.attributes = attributes; + return this; + } + + /** + * Sets an attribute associated to the context. + * + * @param name the name of the attribute + * @param value the value of the attribute + * @return the {@link Builder} + */ + public Builder attribute(String name, Object value) { + if (this.attributes == null) { + this.attributes = new HashMap<>(); + } + this.attributes.put(name, value); + return this; + } + + /** + * Builds a new {@link OAuth2AuthorizationContext}. + * + * @return a {@link OAuth2AuthorizationContext} + */ + public OAuth2AuthorizationContext build() { + Assert.notNull(this.principal, "principal cannot be null"); + OAuth2AuthorizationContext context = new OAuth2AuthorizationContext(); + if (this.authorizedClient != null) { + context.clientRegistration = this.authorizedClient.getClientRegistration(); + context.authorizedClient = this.authorizedClient; + } else { + context.clientRegistration = this.clientRegistration; + } + context.principal = this.principal; + context.attributes = Collections.unmodifiableMap( + CollectionUtils.isEmpty(this.attributes) ? + Collections.emptyMap() : new LinkedHashMap<>(this.attributes)); + return context; + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..b73fc8aafcb --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; + +/** + * A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client. + * Implementations will typically implement a specific {@link AuthorizationGrantType authorization grant} type. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizationContext + * @see Section 1.3 Authorization Grant + */ +public interface OAuth2AuthorizedClientProvider { + + /** + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context. + * Implementations must return {@code null} if authorization is not supported for the specified client, + * e.g. the provider doesn't support the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. + * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported for the specified client + */ + @Nullable + OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context); + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java new file mode 100644 index 00000000000..6405e3bc13e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java @@ -0,0 +1,267 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.util.Assert; + +import java.time.Duration; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +/** + * A builder that builds a {@link DelegatingOAuth2AuthorizedClientProvider} composed of + * one or more {@link OAuth2AuthorizedClientProvider}(s) that implement specific authorization grants. + * The supported authorization grants are {@link #authorizationCode() authorization_code}, + * {@link #refreshToken() refresh_token} and {@link #clientCredentials() client_credentials}. + * In addition to the standard authorization grants, an implementation of an extension grant + * may be supplied via {@link #provider(OAuth2AuthorizedClientProvider)}. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + * @see AuthorizationCodeOAuth2AuthorizedClientProvider + * @see RefreshTokenOAuth2AuthorizedClientProvider + * @see ClientCredentialsOAuth2AuthorizedClientProvider + * @see DelegatingOAuth2AuthorizedClientProvider + */ +public final class OAuth2AuthorizedClientProviderBuilder { + private final Map, Builder> builders = new LinkedHashMap<>(); + + private OAuth2AuthorizedClientProviderBuilder() { + } + + /** + * Returns a new {@link OAuth2AuthorizedClientProviderBuilder} for configuring the supported authorization grant(s). + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public static OAuth2AuthorizedClientProviderBuilder builder() { + return new OAuth2AuthorizedClientProviderBuilder(); + } + + /** + * Configures an {@link OAuth2AuthorizedClientProvider} to be composed with the {@link DelegatingOAuth2AuthorizedClientProvider}. + * This may be used for implementations of extension authorization grants. + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder provider(OAuth2AuthorizedClientProvider provider) { + Assert.notNull(provider, "provider cannot be null"); + this.builders.computeIfAbsent(provider.getClass(), k -> () -> provider); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code authorization_code} grant. + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder authorizationCode() { + this.builders.computeIfAbsent(AuthorizationCodeOAuth2AuthorizedClientProvider.class, k -> new AuthorizationCodeGrantBuilder()); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code authorization_code} grant. + */ + public class AuthorizationCodeGrantBuilder implements Builder { + + private AuthorizationCodeGrantBuilder() { + } + + /** + * Builds an instance of {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. + * + * @return the {@link AuthorizationCodeOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + return new AuthorizationCodeOAuth2AuthorizedClientProvider(); + } + } + + /** + * Configures support for the {@code refresh_token} grant. + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder refreshToken() { + this.builders.computeIfAbsent(RefreshTokenOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code refresh_token} grant. + * + * @param builderConsumer a {@code Consumer} of {@link RefreshTokenGrantBuilder} used for further configuration + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder refreshToken(Consumer builderConsumer) { + RefreshTokenGrantBuilder builder = (RefreshTokenGrantBuilder) this.builders.computeIfAbsent( + RefreshTokenOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); + builderConsumer.accept(builder); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code refresh_token} grant. + */ + public class RefreshTokenGrantBuilder implements Builder { + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + + private RefreshTokenGrantBuilder() { + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder accessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Builds an instance of {@link RefreshTokenOAuth2AuthorizedClientProvider}. + * + * @return the {@link RefreshTokenOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + return authorizedClientProvider; + } + } + + /** + * Configures support for the {@code client_credentials} grant. + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder clientCredentials() { + this.builders.computeIfAbsent(ClientCredentialsOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code client_credentials} grant. + * + * @param builderConsumer a {@code Consumer} of {@link ClientCredentialsGrantBuilder} used for further configuration + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder clientCredentials(Consumer builderConsumer) { + ClientCredentialsGrantBuilder builder = (ClientCredentialsGrantBuilder) this.builders.computeIfAbsent( + ClientCredentialsOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); + builderConsumer.accept(builder); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code client_credentials} grant. + */ + public class ClientCredentialsGrantBuilder implements Builder { + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + + private ClientCredentialsGrantBuilder() { + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder accessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Builds an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider}. + * + * @return the {@link ClientCredentialsOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + return authorizedClientProvider; + } + } + + /** + * Builds an instance of {@link DelegatingOAuth2AuthorizedClientProvider} + * composed of one or more {@link OAuth2AuthorizedClientProvider}(s). + * + * @return the {@link DelegatingOAuth2AuthorizedClientProvider} + */ + public OAuth2AuthorizedClientProvider build() { + List authorizedClientProviders = + this.builders.values().stream() + .map(Builder::build) + .collect(Collectors.toList()); + return new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders); + } + + interface Builder { + OAuth2AuthorizedClientProvider build(); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..36046118948 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + * @see DefaultRefreshTokenTokenResponseClient + */ +public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private OAuth2AccessTokenResponseClient accessTokenResponseClient = + new DefaultRefreshTokenTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + + /** + * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Returns {@code null} if re-authorization is not supported, + * e.g. the client is not authorized OR the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} + * is not available for the authorized client OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * + *

+ * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + *

    + *
  1. {@link OAuth2AuthorizationContext#REQUEST_SCOPE_ATTRIBUTE_NAME} (optional) - a {@code String[]} of scope(s) + * to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
  2. + *
+ * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if re-authorization is not supported + */ + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient == null || + authorizedClient.getRefreshToken() == null || + !hasTokenExpired(authorizedClient.getAccessToken())) { + return null; + } + + Object requestScope = context.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + Set scopes = Collections.emptySet(); + if (requestScope != null) { + Assert.isInstanceOf(String[].class, requestScope, + "The context attribute must be of type String[] '" + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); + } + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(), + authorizedClient.getRefreshToken(), scopes); + OAuth2AccessTokenResponse tokenResponse = + this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + + return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(), + context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + } + + private boolean hasTokenExpired(AbstractOAuth2Token token) { + return token.getExpiresAt().isBefore(Instant.now().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant + */ + public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java new file mode 100644 index 00000000000..0efd37d8ebd --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; + +import java.util.Arrays; + +/** + * The default implementation of an {@link OAuth2AccessTokenResponseClient} + * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. + * This implementation uses a {@link RestOperations} when requesting + * an access token credential at the Authorization Server's Token Endpoint. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AccessTokenResponseClient + * @see OAuth2RefreshTokenGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section 6 Refreshing an Access Token + */ +public final class DefaultRefreshTokenTokenResponseClient implements OAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + + private Converter> requestEntityConverter = + new OAuth2RefreshTokenGrantRequestEntityConverter(); + + private RestOperations restOperations; + + public DefaultRefreshTokenTokenResponseClient() { + RestTemplate restTemplate = new RestTemplate(Arrays.asList( + new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + this.restOperations = restTemplate; + } + + @Override + public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + Assert.notNull(refreshTokenGrantRequest, "refreshTokenGrantRequest cannot be null"); + + RequestEntity request = this.requestEntityConverter.convert(refreshTokenGrantRequest); + + ResponseEntity response; + try { + response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + + OAuth2AccessTokenResponse tokenResponse = response.getBody(); + + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes()) || + tokenResponse.getRefreshToken() == null) { + OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse.withResponse(tokenResponse); + + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { + // As per spec, in Section 5.1 Successful Access Token Response + // https://tools.ietf.org/html/rfc6749#section-5.1 + // If AccessTokenResponse.scope is empty, then default to the scope + // originally requested by the client in the Token Request + tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes()); + } + + if (tokenResponse.getRefreshToken() == null) { + // Reuse existing refresh token + tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue()); + } + + tokenResponse = tokenResponseBuilder.build(); + } + + return tokenResponse; + } + + /** + * Sets the {@link Converter} used for converting the {@link OAuth2RefreshTokenGrantRequest} + * to a {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. + * + * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the Access Token Request + */ + public void setRequestEntityConverter(Converter> requestEntityConverter) { + Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); + this.requestEntityConverter = requestEntityConverter; + } + + /** + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token Response. + * + *

+ * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + *

    + *
  1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and {@link OAuth2AccessTokenResponseHttpMessageConverter}
  2. + *
  3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
  4. + *
+ * + * @param restOperations the {@link RestOperations} used when requesting the Access Token Response + */ + public void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java new file mode 100644 index 00000000000..a93b76fdd41 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.util.Assert; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +/** + * An OAuth 2.0 Refresh Token Grant request that holds the {@link OAuth2RefreshToken refresh token} credential + * granted to the {@link #getClientRegistration() client}. + * + * @author Joe Grandja + * @since 5.2 + * @see AbstractOAuth2AuthorizationGrantRequest + * @see OAuth2RefreshToken + * @see Section 6 Refreshing an Access Token + */ +public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + private final ClientRegistration clientRegistration; + private final OAuth2AccessToken accessToken; + private final OAuth2RefreshToken refreshToken; + private final Set scopes; + + /** + * Constructs an {@code OAuth2RefreshTokenGrantRequest} using the provided parameters. + * + * @param clientRegistration the authorized client's registration + * @param accessToken the access token credential granted + * @param refreshToken the refresh token credential granted + */ + public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, + OAuth2RefreshToken refreshToken) { + this(clientRegistration, accessToken, refreshToken, Collections.emptySet()); + } + + /** + * Constructs an {@code OAuth2RefreshTokenGrantRequest} using the provided parameters. + * + * @param clientRegistration the authorized client's registration + * @param accessToken the access token credential granted + * @param refreshToken the refresh token credential granted + * @param scopes the scopes to request + */ + public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, + OAuth2RefreshToken refreshToken, Set scopes) { + super(AuthorizationGrantType.REFRESH_TOKEN); + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + Assert.notNull(accessToken, "accessToken cannot be null"); + Assert.notNull(refreshToken, "refreshToken cannot be null"); + this.clientRegistration = clientRegistration; + this.accessToken = accessToken; + this.refreshToken = refreshToken; + this.scopes = Collections.unmodifiableSet(scopes != null ? + new LinkedHashSet<>(scopes) : Collections.emptySet()); + } + + /** + * Returns the authorized client's {@link ClientRegistration registration}. + * + * @return the {@link ClientRegistration} + */ + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + + /** + * Returns the {@link OAuth2AccessToken access token} credential granted. + * + * @return the {@link OAuth2AccessToken} + */ + public OAuth2AccessToken getAccessToken() { + return this.accessToken; + } + + /** + * Returns the {@link OAuth2RefreshToken refresh token} credential granted. + * + * @return the {@link OAuth2RefreshToken} + */ + public OAuth2RefreshToken getRefreshToken() { + return this.refreshToken; + } + + /** + * Returns the scope(s) to request. + * + * @return the scope(s) to request + */ + public Set getScopes() { + return this.scopes; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java new file mode 100644 index 00000000000..00cac8beed0 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; + +import java.net.URI; + +/** + * A {@link Converter} that converts the provided {@link OAuth2RefreshTokenGrantRequest} + * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request + * for the Refresh Token Grant. + * + * @author Joe Grandja + * @since 5.2 + * @see Converter + * @see OAuth2RefreshTokenGrantRequest + * @see RequestEntity + */ +public class OAuth2RefreshTokenGrantRequestEntityConverter implements Converter> { + + /** + * Returns the {@link RequestEntity} used for the Access Token Request. + * + * @param refreshTokenGrantRequest the refresh token grant request + * @return the {@link RequestEntity} used for the Access Token Request + */ + @Override + public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); + + HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); + MultiValueMap formParameters = buildFormParameters(refreshTokenGrantRequest); + URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()) + .build() + .toUri(); + + return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); + } + + /** + * Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body. + * + * @param refreshTokenGrantRequest the refresh token grant request + * @return a {@link MultiValueMap} of the form parameters used for the Access Token Request body + */ + private MultiValueMap buildFormParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); + + MultiValueMap formParameters = new LinkedMultiValueMap<>(); + formParameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); + formParameters.add(OAuth2ParameterNames.REFRESH_TOKEN, + refreshTokenGrantRequest.getRefreshToken().getTokenValue()); + if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) { + formParameters.add(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " ")); + } + if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { + formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + + return formParameters; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java new file mode 100644 index 00000000000..d3644186136 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +/** + * The default implementation of an {@link OAuth2AuthorizedClientManager}. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientManager + * @see OAuth2AuthorizedClientProvider + */ +public final class DefaultOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null; + private Function> contextAttributesMapper = new DefaultContextAttributesMapper(); + + /** + * Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public DefaultOAuth2AuthorizedClientManager(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + @Nullable + @Override + public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { + Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); + + String clientRegistrationId = authorizeRequest.getClientRegistrationId(); + OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); + Authentication principal = authorizeRequest.getPrincipal(); + HttpServletRequest servletRequest = authorizeRequest.getServletRequest(); + HttpServletResponse servletResponse = authorizeRequest.getServletResponse(); + + OAuth2AuthorizationContext.Builder contextBuilder; + if (authorizedClient != null) { + contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); + } else { + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, principal, servletRequest); + if (authorizedClient != null) { + contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); + } else { + contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration); + } + } + OAuth2AuthorizationContext authorizationContext = contextBuilder + .principal(principal) + .attributes(this.contextAttributesMapper.apply(authorizeRequest)) + .build(); + + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + if (authorizedClient != null) { + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, servletRequest, servletResponse); + } else { + // In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported. + // For these cases, return the provided `authorizationContext.authorizedClient`. + if (authorizationContext.getAuthorizedClient() != null) { + return authorizationContext.getAuthorizedClient(); + } + } + + return authorizedClient; + } + + /** + * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * + * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + */ + public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.authorizedClientProvider = authorizedClientProvider; + } + + /** + * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes + * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * + * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes + * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + */ + public void setContextAttributesMapper(Function> contextAttributesMapper) { + Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); + this.contextAttributesMapper = contextAttributesMapper; + } + + /** + * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + */ + public static class DefaultContextAttributesMapper implements Function> { + + @Override + public Map apply(OAuth2AuthorizeRequest authorizeRequest) { + Map contextAttributes = Collections.emptyMap(); + String scope = authorizeRequest.getServletRequest().getParameter(OAuth2ParameterNames.SCOPE); + if (StringUtils.hasText(scope)) { + contextAttributes = new HashMap<>(); + contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, + StringUtils.delimitedListToStringArray(scope, " ")); + } + return contextAttributes; + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java new file mode 100644 index 00000000000..7f221183855 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Represents a request the {@link OAuth2AuthorizedClientManager} uses to + * {@link OAuth2AuthorizedClientManager#authorize(OAuth2AuthorizeRequest) authorize} (or re-authorize) + * the {@link ClientRegistration client} identified by the provided {@link #getClientRegistrationId() clientRegistrationId}. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientManager + */ +public class OAuth2AuthorizeRequest { + private final String clientRegistrationId; + private final OAuth2AuthorizedClient authorizedClient; + private final Authentication principal; + private final HttpServletRequest servletRequest; + private final HttpServletResponse servletResponse; + + /** + * Constructs an {@code OAuth2AuthorizeRequest} using the provided parameters. + * + * @param clientRegistrationId the identifier for the {@link ClientRegistration client registration} + * @param principal the {@code Principal} (to be) associated to the authorized client + * @param servletRequest the {@code HttpServletRequest} + * @param servletResponse the {@code HttpServletResponse} + */ + public OAuth2AuthorizeRequest(String clientRegistrationId, Authentication principal, + HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(servletRequest, "servletRequest cannot be null"); + Assert.notNull(servletResponse, "servletResponse cannot be null"); + this.clientRegistrationId = clientRegistrationId; + this.authorizedClient = null; + this.principal = principal; + this.servletRequest = servletRequest; + this.servletResponse = servletResponse; + } + + /** + * Constructs an {@code OAuth2AuthorizeRequest} using the provided parameters. + * + * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} + * @param principal the {@code Principal} associated to the authorized client + * @param servletRequest the {@code HttpServletRequest} + * @param servletResponse the {@code HttpServletResponse} + */ + public OAuth2AuthorizeRequest(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(servletRequest, "servletRequest cannot be null"); + Assert.notNull(servletResponse, "servletResponse cannot be null"); + this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId(); + this.authorizedClient = authorizedClient; + this.principal = principal; + this.servletRequest = servletRequest; + this.servletResponse = servletResponse; + } + + /** + * Returns the identifier for the {@link ClientRegistration client registration}. + * + * @return the identifier for the client registration + */ + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + /** + * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it was not provided. + * + * @return the {@link OAuth2AuthorizedClient} or {@code null} if it was not provided + */ + @Nullable + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + /** + * Returns the {@code Principal} (to be) associated to the authorized client. + * + * @return the {@code Principal} (to be) associated to the authorized client + */ + public Authentication getPrincipal() { + return this.principal; + } + + /** + * Returns the {@code HttpServletRequest}. + * + * @return the {@code HttpServletRequest} + */ + public HttpServletRequest getServletRequest() { + return this.servletRequest; + } + + /** + * Returns the {@code HttpServletResponse}. + * + * @return the {@code HttpServletResponse} + */ + public HttpServletResponse getServletResponse() { + return this.servletResponse; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java new file mode 100644 index 00000000000..af90c1600e0 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ClientRegistration; + +/** + * Implementations of this interface are responsible for the overall management + * of {@link OAuth2AuthorizedClient Authorized Client(s)}. + * + *

+ * The primary responsibilities include: + *

    + *
  1. Authorizing (or re-authorizing) an OAuth 2.0 Client + * by leveraging an {@link OAuth2AuthorizedClientProvider}(s).
  2. + *
  3. Managing the persistence of an {@link OAuth2AuthorizedClient} between requests, + * typically using an {@link OAuth2AuthorizedClientRepository}.
  4. + *
+ * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientProvider + * @see OAuth2AuthorizedClientRepository + */ +public interface OAuth2AuthorizedClientManager { + + /** + * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration client} + * identified by the provided {@link OAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. + * Implementations must return {@code null} if authorization is not supported for the specified client, + * e.g. the associated {@link OAuth2AuthorizedClientProvider}(s) does not support + * the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. + * + *

+ * In the case of re-authorization, implementations must return the provided {@link OAuth2AuthorizeRequest#getAuthorizedClient() authorized client} + * if re-authorization is not supported for the client OR is not required, + * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR + * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * + * @param authorizeRequest the authorize request + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported for the specified client + */ + @Nullable + OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest); + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 7a61e319c65..de931ba5218 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,20 +19,23 @@ import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.bind.support.WebDataBinderFactory; @@ -64,10 +67,21 @@ * @see RegisteredOAuth2AuthorizedClient */ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { - private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientRepository authorizedClientRepository; - private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = - new DefaultClientCredentialsTokenResponseClient(); + private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( + "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + private OAuth2AuthorizedClientManager authorizedClientManager; + private boolean defaultAuthorizedClientManager; + + /** + * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. + * + * @since 5.2 + * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) + */ + public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager authorizedClientManager) { + Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); + this.authorizedClientManager = authorizedClientManager; + } /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. @@ -79,8 +93,24 @@ public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clien OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; + this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + this.defaultAuthorizedClientManager = true; + } + + private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( + ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; } @Override @@ -106,29 +136,16 @@ public Object resolveArgument(MethodParameter parameter, } Authentication principal = SecurityContextHolder.getContext().getAuthentication(); - HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, principal, servletRequest); - if (authorizedClient != null) { - return authorizedClient; - } - - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - if (clientRegistration == null) { - return null; - } - - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - throw new ClientAuthorizationRequiredException(clientRegistrationId); + if (principal == null) { + principal = ANONYMOUS_AUTHENTICATION; } + HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); + HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { - HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - authorizedClient = this.authorizeClientCredentialsClient(clientRegistration, servletRequest, servletResponse); - } + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + clientRegistrationId, principal, servletRequest, servletResponse); - return authorizedClient; + return this.authorizedClientManager.authorize(authorizeRequest); } private String resolveClientRegistrationId(MethodParameter parameter) { @@ -149,37 +166,34 @@ private String resolveClientRegistrationId(MethodParameter parameter) { return clientRegistrationId; } - private OAuth2AuthorizedClient authorizeClientCredentialsClient(ClientRegistration clientRegistration, - HttpServletRequest request, HttpServletResponse response) { - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(clientRegistration); - OAuth2AccessTokenResponse tokenResponse = - this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - - Authentication principal = SecurityContextHolder.getContext().getAuthentication(); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, - (principal != null ? principal.getName() : "anonymousUser"), - tokenResponse.getAccessToken()); - - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - principal, - request, - response); - - return authorizedClient; - } - /** * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. * + * @deprecated Use {@link #OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)} instead. + * Create an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a + * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} + * (or a custom one) and than supply it to {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}. + * * @param clientCredentialsTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant */ + @Deprecated public final void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + updateDefaultAuthorizedClientManager(clientCredentialsTokenResponseClient); + } + + private void updateDefaultAuthorizedClientManager( + OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) + .build(); + ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 0054ca164c5..9919bf859f7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -19,29 +19,28 @@ import org.reactivestreams.Subscription; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; -import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; @@ -51,22 +50,15 @@ import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; -import reactor.core.scheduler.Schedulers; import reactor.util.context.Context; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.net.URI; -import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.Collection; import java.util.Map; -import java.util.Optional; import java.util.function.Consumer; -import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; - /** * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the * token as a Bearer Token. It also provides mechanisms for looking up the {@link OAuth2AuthorizedClient}. This class is @@ -75,7 +67,7 @@ * Example usage: * *

- * OAuth2AuthorizedClientExchangeFilterFunction oauth2 = new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService);
+ * ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository);
  * WebClient webClient = WebClient.builder()
  *    .apply(oauth2.oauth2Configuration())
  *    .build();
@@ -92,18 +84,18 @@
  * are true:
  *
  * 
    - *
  • The ReactiveOAuth2AuthorizedClientService on the - * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction} is not null
  • - *
  • A refresh token is present on the OAuth2AuthorizedClient
  • - *
  • The access token will be expired in - * {@link #setAccessTokenExpiresSkew(Duration)}
  • - *
  • The {@link ReactiveSecurityContextHolder} will be used to attempt to save - * the token. If it is empty, then the principal name on the OAuth2AuthorizedClient + *
  • The {@link OAuth2AuthorizedClientManager} is not null
  • + *
  • A refresh token is present on the {@link OAuth2AuthorizedClient}
  • + *
  • The access token is expired
  • + *
  • The {@link SecurityContextHolder} will be used to attempt to save + * the token. If it is empty, then the principal name on the {@link OAuth2AuthorizedClient} * will be used to create an Authentication for saving.
  • *
* * @author Rob Winch + * @author Joe Grandja * @since 5.1 + * @see OAuth2AuthorizedClientManager */ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction, InitializingBean, DisposableBean { @@ -119,16 +111,18 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); - private Clock clock = Clock.systemUTC(); + private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( + "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + @Deprecated private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); - private ClientRegistrationRepository clientRegistrationRepository; + @Deprecated + private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; - private OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientManager authorizedClientManager; - private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = - new DefaultClientCredentialsTokenResponseClient(); + private boolean defaultAuthorizedClientManager; private boolean defaultOAuth2AuthorizedClient; @@ -137,11 +131,44 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction public ServletOAuth2AuthorizedClientExchangeFilterFunction() { } + /** + * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * + * @since 5.2 + * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) + */ + public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager authorizedClientManager) { + Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); + this.authorizedClientManager = authorizedClientManager; + } + + /** + * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ public ServletOAuth2AuthorizedClientExchangeFilterFunction( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; + this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + this.defaultAuthorizedClientManager = true; + } + + private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( + ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; } @Override @@ -155,14 +182,40 @@ public void destroy() throws Exception { } /** - * Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for - * client_credentials grant. + * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant. + * + * @deprecated Use {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} instead. + * Create an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a + * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} + * (or a custom one) and than supply it to {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}. + * * @param clientCredentialsTokenResponseClient the client to use */ + @Deprecated public void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); + Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + updateDefaultAuthorizedClientManager(); + } + + private void updateDefaultAuthorizedClientManager() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) + .clientCredentials(this::updateClientCredentialsProvider) + .build(); + ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); + } + + private void updateClientCredentialsProvider(OAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) { + if (this.clientCredentialsTokenResponseClient != null) { + builder.accessTokenResponseClient(this.clientCredentialsTokenResponseClient); + } + builder.clockSkew(this.accessTokenExpiresSkew); } /** @@ -176,7 +229,6 @@ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClie this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; } - /** * If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is * recommended to be cautious with this feature since all HTTP requests will receive the access token. @@ -279,11 +331,20 @@ public static Consumer> httpServletResponse(HttpServletRespo /** * An access token will be considered expired by comparing its expiration to now + * this skewed Duration. The default is 1 minute. + * + * @deprecated The {@code accessTokenExpiresSkew} should be configured with the specific {@link OAuth2AuthorizedClientProvider} implementation, + * e.g. {@link ClientCredentialsOAuth2AuthorizedClientProvider#setClockSkew(Duration) ClientCredentialsOAuth2AuthorizedClientProvider} or + * {@link RefreshTokenOAuth2AuthorizedClientProvider#setClockSkew(Duration) RefreshTokenOAuth2AuthorizedClientProvider}. + * * @param accessTokenExpiresSkew the Duration to use. */ + @Deprecated public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); + Assert.state(this.defaultAuthorizedClientManager, "The accessTokenExpiresSkew cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); this.accessTokenExpiresSkew = accessTokenExpiresSkew; + updateDefaultAuthorizedClientManager(); } @Override @@ -292,7 +353,7 @@ public Mono filter(ClientRequest request, ExchangeFunction next) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) .switchIfEmpty(mergeRequestAttributesFromContext(request)) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) - .flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes()))) + .flatMap(req -> authorizedClient(getOAuth2AuthorizedClient(req.attributes()), req)) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); @@ -319,8 +380,8 @@ private void populateRequestAttributes(Map attrs, Context ctx) { } private void populateDefaultRequestResponse(Map attrs) { - if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey( - HTTP_SERVLET_RESPONSE_ATTR_NAME)) { + if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && + attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { return; } ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); @@ -343,8 +404,8 @@ private void populateDefaultAuthentication(Map attrs) { } private void populateDefaultOAuth2AuthorizedClient(Map attrs) { - if (this.authorizedClientRepository == null - || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) { + if (this.authorizedClientManager == null || + attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) { return; } @@ -360,116 +421,30 @@ private void populateDefaultOAuth2AuthorizedClient(Map attrs) { } if (clientRegistrationId != null) { HttpServletRequest request = getRequest(attrs); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository - .loadAuthorizedClient(clientRegistrationId, authentication, - request); - if (authorizedClient == null) { - authorizedClient = getAuthorizedClient(clientRegistrationId, attrs); + if (authentication == null) { + authentication = ANONYMOUS_AUTHENTICATION; } + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + clientRegistrationId, authentication, request, getResponse(attrs)); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); oauth2AuthorizedClient(authorizedClient).accept(attrs); } } - private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map attrs) { - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - if (clientRegistration == null) { - throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); - } - if (isClientCredentialsGrantType(clientRegistration)) { - return authorizeWithClientCredentials(clientRegistration, attrs); - } - throw new ClientAuthorizationRequiredException(clientRegistrationId); - } - - private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) { - return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()); - } - - private OAuth2AuthorizedClient authorizeWithClientCredentials( - ClientRegistration clientRegistration, Map attrs) { - HttpServletRequest request = getRequest(attrs); - HttpServletResponse response = getResponse(attrs); - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(clientRegistration); - OAuth2AccessTokenResponse tokenResponse = - this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - - Authentication principal = getAuthentication(attrs); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, - (principal != null ? principal.getName() : "anonymousUser"), - tokenResponse.getAccessToken()); - - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - principal, - request, - response); - - return authorizedClient; - } - - private Mono authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { - ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); - if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) { - // Client credentials grant do not have refresh tokens but can expire so we need to get another one - return Mono.fromSupplier(() -> authorizeWithClientCredentials(clientRegistration, request.attributes())); - } else if (shouldRefreshToken(authorizedClient)) { - return authorizeWithRefreshToken(request, next, authorizedClient); + private Mono authorizedClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) { + if (this.authorizedClientManager == null) { + return Mono.just(authorizedClient); } - return Mono.just(authorizedClient); - } - - private Mono authorizeWithRefreshToken(ClientRequest request, ExchangeFunction next, - OAuth2AuthorizedClient authorizedClient) { - ClientRegistration clientRegistration = authorizedClient - .getClientRegistration(); - String tokenUri = clientRegistration - .getProviderDetails().getTokenUri(); - ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri)) - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret())) - .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue())) - .build(); - return next.exchange(refreshRequest) - .flatMap(response -> response.body(oauth2AccessTokenResponse())) - .map(accessTokenResponse -> { - OAuth2RefreshToken refreshToken = Optional.ofNullable(accessTokenResponse.getRefreshToken()) - .orElse(authorizedClient.getRefreshToken()); - return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), refreshToken); - }) - .map(result -> { - Authentication principal = (Authentication) request.attribute( - AUTHENTICATION_ATTR_NAME).orElse(new PrincipalNameAuthentication(authorizedClient.getPrincipalName())); - HttpServletRequest httpRequest = (HttpServletRequest) request.attributes().get( - HTTP_SERVLET_REQUEST_ATTR_NAME); - HttpServletResponse httpResponse = (HttpServletResponse) request.attributes().get( - HTTP_SERVLET_RESPONSE_ATTR_NAME); - this.authorizedClientRepository.saveAuthorizedClient(result, principal, httpRequest, httpResponse); - return result; - }) - .publishOn(Schedulers.elastic()); - } - - private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) { - if (this.authorizedClientRepository == null) { - return false; - } - OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); - if (refreshToken == null) { - return false; - } - return hasTokenExpired(authorizedClient); - } - - private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { - Instant now = this.clock.instant(); - Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); - if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) { - return true; + Map attrs = request.attributes(); + Authentication authentication = getAuthentication(attrs); + if (authentication == null) { + authentication = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); } - return false; + HttpServletRequest servletRequest = getRequest(attrs); + HttpServletResponse servletResponse = getResponse(attrs); + OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( + authorizedClient, authentication, servletRequest, servletResponse); + return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)); } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { @@ -491,12 +466,6 @@ private CoreSubscriber createRequestContextSubscriber(CoreSubscriber d return new RequestContextSubscriber<>(delegate, request, response, authentication); } - private static BodyInserters.FormInserter refreshTokenBody(String refreshToken) { - return BodyInserters - .fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) - .with("refresh_token", refreshToken); - } - static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map attrs) { return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME); } @@ -518,10 +487,11 @@ static HttpServletResponse getResponse(Map attrs) { } private static class PrincipalNameAuthentication implements Authentication { - private final String username; + private final String principalName; - private PrincipalNameAuthentication(String username) { - this.username = username; + private PrincipalNameAuthentication(String principalName) { + Assert.hasText(principalName, "principalName cannot be empty"); + this.principalName = principalName; } @Override @@ -541,7 +511,7 @@ public Object getDetails() { @Override public Object getPrincipal() { - throw unsupported(); + return getName(); } @Override @@ -550,14 +520,13 @@ public boolean isAuthenticated() { } @Override - public void setAuthenticated(boolean isAuthenticated) - throws IllegalArgumentException { + public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException { throw unsupported(); } @Override public String getName() { - return this.username; + return this.principalName; } private UnsupportedOperationException unsupported() { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..c393b9f3235 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class AuthorizationCodeOAuth2AuthorizedClientProviderTests { + private AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + + @Before + public void setup() { + this.authorizedClientProvider = new AuthorizationCodeOAuth2AuthorizedClientProvider(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { + ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(clientCredentialsClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..10acb7cdac8 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,149 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; + +import java.time.Duration; +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link ClientCredentialsOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class ClientCredentialsOAuth2AuthorizedClientProviderTests { + private ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; + + @Before + public void setup() { + this.authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + this.clientRegistration = TestClientRegistrations.clientCredentials().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessTokenResponseClient cannot be null"); + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew cannot be null"); + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew must be >= 0"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + + @Test + public void authorizeWhenClientCredentialsAndTokenExpiredThenReauthorize() { + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), accessToken); + + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + + @Test + public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..f930233aa83 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DelegatingOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class DelegatingOAuth2AuthorizedClientProviderTests { + + @Test + public void constructorWhenProvidersIsEmptyThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DelegatingOAuth2AuthorizedClientProvider(new OAuth2AuthorizedClientProvider[0])) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new DelegatingOAuth2AuthorizedClientProvider(Collections.emptyList())) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + mock(OAuth2AuthorizedClientProvider.class)); + assertThatThrownBy(() -> delegate.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { + Authentication principal = new TestingAuthenticationToken("principal", "password"); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, principal.getName(), TestOAuth2AccessTokens.noScopes()); + + OAuth2AuthorizedClientProvider authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); + when(authorizedClientProvider.authorize(any())).thenReturn(authorizedClient); + + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class), authorizedClientProvider); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .principal(principal) + .build(); + OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context); + assertThat(reauthorizedClient).isSameAs(authorizedClient); + } + + @Test + public void authorizeWhenProviderCantAuthorizeThenReturnNull() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .principal(new TestingAuthenticationToken("principal", "password")) + .build(); + + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class)); + assertThat(delegate.authorize(context)).isNull(); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java new file mode 100644 index 00000000000..89236d4c4ff --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import static org.assertj.core.api.Assertions.*; + +/** + * Tests for {@link OAuth2AuthorizationContext}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizationContextTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + + @Before + public void setup() { + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void forClientWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration((ClientRegistration) null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistration cannot be null"); + } + + @Test + public void forClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.withAuthorizedClient((OAuth2AuthorizedClient) null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + } + + @Test + public void forClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + } + + @Test + public void forClientWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute("attribute1", "value1") + .attribute("attribute2", "value2") + .build(); + assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); + assertThat(authorizationContext.getAttributes()).contains( + entry("attribute1", "value1"), entry("attribute2", "value2")); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java new file mode 100644 index 00000000000..cbc1880877c --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java @@ -0,0 +1,202 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpStatus; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.web.client.RestOperations; + +import java.time.Duration; +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link OAuth2AuthorizedClientProviderBuilder}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizedClientProviderBuilderTests { + private RestOperations accessTokenClient; + private DefaultClientCredentialsTokenResponseClient clientCredentialsTokenResponseClient; + private DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient; + private Authentication principal; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + this.accessTokenClient = mock(RestOperations.class); + when(this.accessTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .thenReturn(new ResponseEntity(accessTokenResponse, HttpStatus.OK)); + this.refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + this.refreshTokenTokenResponseClient.setRestOperations(this.accessTokenClient); + this.clientCredentialsTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + this.clientCredentialsTokenResponseClient.setRestOperations(this.accessTokenClient); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void providerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizedClientProviderBuilder.builder().provider(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientRegistration().build()) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } + + @Test + public void buildWhenRefreshTokenProviderThenProviderReauthorizes() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .build(); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + TestClientRegistrations.clientRegistration().build(), + this.principal.getName(), + expiredAccessToken(), + TestOAuth2RefreshTokens.refreshToken()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext); + + assertThat(reauthorizedClient).isNotNull(); + verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + + @Test + public void buildWhenClientCredentialsProviderThenProviderAuthorizes() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientCredentials().build()) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); + + assertThat(authorizedClient).isNotNull(); + verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + + @Test + public void buildWhenAllProvidersThenProvidersAuthorize() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .build(); + + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + + + // authorization_code + OAuth2AuthorizationContext authorizationCodeContext = + OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + + + // refresh_token + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, + this.principal.getName(), + expiredAccessToken(), + TestOAuth2RefreshTokens.refreshToken()); + + OAuth2AuthorizationContext refreshTokenContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(refreshTokenContext); + + assertThat(reauthorizedClient).isNotNull(); + verify(this.accessTokenClient, times(1)).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + + + // client_credentials + OAuth2AuthorizationContext clientCredentialsContext = + OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientCredentials().build()) + .principal(this.principal) + .build(); + authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext); + + assertThat(authorizedClient).isNotNull(); + verify(this.accessTokenClient, times(2)).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + + @Test + public void buildWhenCustomProviderThenProviderCalled() { + OAuth2AuthorizedClientProvider customProvider = mock(OAuth2AuthorizedClientProvider.class); + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .provider(customProvider) + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientRegistration().build()) + .principal(this.principal) + .build(); + authorizedClientProvider.authorize(authorizationContext); + + verify(customProvider).authorize(any(OAuth2AuthorizationContext.class)); + } + + private OAuth2AccessToken expiredAccessToken() { + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..06124d06b08 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,187 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link RefreshTokenOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class RefreshTokenOAuth2AuthorizedClientProviderTests { + private RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + + @Before + public void setup() { + this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + expiredAccessToken, TestOAuth2RefreshTokens.refreshToken()); + } + + @Test + public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessTokenResponseClient cannot be null"); + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew cannot be null"); + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew must be >= 0"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); + } + + @Test + public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + String[] requestScope = new String[] { "read", "write" }; + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) + .build(); + + this.authorizedClientProvider.authorize(authorizationContext); + + ArgumentCaptor refreshTokenGrantRequestArgCaptor = + ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); + assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(new HashSet<>(Arrays.asList(requestScope))); + } + + @Test + public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { + String invalidRequestScope = "read write"; + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) + .build(); + + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("The context attribute must be of type String[] '" + + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java new file mode 100644 index 00000000000..5902eb4543f --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java @@ -0,0 +1,221 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; + +import java.time.Instant; +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link DefaultRefreshTokenTokenResponseClient}. + * + * @author Joe Grandja + */ +public class DefaultRefreshTokenTokenResponseClientTests { + private DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; + private MockWebServer server; + + @Before + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); + this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); + } + + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("grant_type=refresh_token"); + assertThat(formParameters).contains("refresh_token=refresh-token"); + + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(this.accessToken.getScopes().toArray(new String[0])); + assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.refreshToken.getTokenValue()); + } + + @Test + public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .clientAuthenticationMethod(ClientAuthenticationMethod.POST) + .build(); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); + + this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-id"); + assertThat(formParameters).contains("client_secret=client-secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .hasMessageContaining("tokenType cannot be null"); + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, Collections.singleton("read")); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("scope=read"); + + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("[unauthorized_client]"); + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(500)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessage("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: 500 Server Error"); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java new file mode 100644 index 00000000000..2f73174039f --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.MultiValueMap; + +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; + +/** + * Tests for {@link OAuth2RefreshTokenGrantRequestEntityConverter}. + * + * @author Joe Grandja + */ +public class OAuth2RefreshTokenGrantRequestEntityConverterTests { + private OAuth2RefreshTokenGrantRequestEntityConverter converter = new OAuth2RefreshTokenGrantRequestEntityConverter(); + private OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest; + + @Before + public void setup() { + this.refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.scopes("read", "write"), + TestOAuth2RefreshTokens.refreshToken(), + Collections.singleton("read")); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenGrantRequestValidThenConverts() { + RequestEntity requestEntity = this.converter.convert(this.refreshTokenGrantRequest); + + ClientRegistration clientRegistration = this.refreshTokenGrantRequest.getClientRegistration(); + OAuth2RefreshToken refreshToken = this.refreshTokenGrantRequest.getRefreshToken(); + + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); + assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( + clientRegistration.getProviderDetails().getTokenUri()); + + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getContentType()).isEqualTo( + MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( + AuthorizationGrantType.REFRESH_TOKEN.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN)).isEqualTo( + refreshToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read"); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java new file mode 100644 index 00000000000..dc90a388a5c --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2RefreshTokenGrantRequest}. + * + * @author Joe Grandja + */ +public class OAuth2RefreshTokenGrantRequestTests { + private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; + + @Before + public void setup() { + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); + } + + @Test + public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(null, this.accessToken, this.refreshToken)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistration cannot be null"); + } + + @Test + public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.clientRegistration, null, this.refreshToken)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessToken cannot be null"); + } + + @Test + public void constructorWhenRefreshTokenIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.clientRegistration, this.accessToken, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("refreshToken cannot be null"); + } + + @Test + public void constructorWhenValidParametersProvidedThenCreated() { + Set scopes = new HashSet<>(Arrays.asList("read", "write")); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistration, this.accessToken, this.refreshToken, scopes); + assertThat(refreshTokenGrantRequest.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(refreshTokenGrantRequest.getAccessToken()).isSameAs(this.accessToken); + assertThat(refreshTokenGrantRequest.getRefreshToken()).isSameAs(this.refreshToken); + assertThat(refreshTokenGrantRequest.getScopes()).isEqualTo(scopes); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java new file mode 100644 index 00000000000..1d200fc323d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java @@ -0,0 +1,282 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; + +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link DefaultOAuth2AuthorizedClientManager}. + * + * @author Joe Grandja + */ +public class DefaultOAuth2AuthorizedClientManagerTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientProvider authorizedClientProvider; + private Function contextAttributesMapper; + private DefaultOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private MockHttpServletRequest request; + private MockHttpServletResponse response; + private ArgumentCaptor authorizationContextCaptor; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); + this.contextAttributesMapper = mock(Function.class); + this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); + this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientManager(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + + @Test + public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientProvider cannot be null"); + } + + @Test + public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("contextAttributesMapper cannot be null"); + } + + @Test + public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizeRequest cannot be null"); + } + + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + "invalid-registration-id", this.principal, this.request, this.response); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isNull(); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient); + + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(this.authorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.request))).thenReturn(this.authorizedClient); + + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(any()); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { + OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenSupportedProviderThenReauthorized() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenRequestScopeParameterThenMappedToContext() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + // Override the mock with the default + this.authorizedClientManager.setContextAttributesMapper( + new DefaultOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); + + this.request.addParameter(OAuth2ParameterNames.SCOPE, "read write"); + + OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + assertThat(requestScopeAttribute).contains("read", "write"); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java new file mode 100644 index 00000000000..6d7e687fcdd --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2AuthorizeRequest}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizeRequestTests { + private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + private Authentication principal = new TestingAuthenticationToken("principal", "password"); + private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + private MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + private MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + + @Test + public void constructorWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest((String) null, this.principal, this.servletRequest, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationId cannot be empty"); + } + + @Test + public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest((OAuth2AuthorizedClient) null, this.principal, this.servletRequest, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + } + + @Test + public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), null, this.servletRequest, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + } + + @Test + public void constructorWhenServletRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, null, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("servletRequest cannot be null"); + } + + @Test + public void constructorWhenServletResponseIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, this.servletRequest, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("servletResponse cannot be null"); + } + + @Test + public void constructorClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.servletRequest, this.servletResponse); + + assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); + assertThat(authorizeRequest.getAuthorizedClient()).isNull(); + assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); + assertThat(authorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); + } + + @Test + public void constructorAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.servletRequest, this.servletResponse); + + assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.authorizedClient.getClientRegistration().getRegistrationId()); + assertThat(authorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient); + assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); + assertThat(authorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 508c3eda454..d5f7094e31e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,19 +20,25 @@ import org.junit.Test; import org.springframework.core.MethodParameter; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -42,6 +48,7 @@ import org.springframework.web.context.request.ServletWebRequest; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import java.lang.reflect.Method; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -67,6 +74,7 @@ public class OAuth2AuthorizedClientArgumentResolverTests { private OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientArgumentResolver argumentResolver; private MockHttpServletRequest request; + private MockHttpServletResponse response; @Before public void setup() { @@ -98,8 +106,16 @@ public void setup() { .build(); this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.registration2); this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver( + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); this.authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName, mock(OAuth2AccessToken.class)); when(this.authorizedClientRepository.loadAuthorizedClient( eq(this.registration1.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class))) @@ -109,6 +125,7 @@ public void setup() { eq(this.registration2.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class))) .thenReturn(this.authorizedClient2); this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); } @After @@ -128,10 +145,25 @@ public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllega .isInstanceOf(IllegalArgumentException.class); } + @Test + public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientCredentialsTokenResponseClient cannot be null"); + } + + @Test + public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { + assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(new DefaultClientCredentialsTokenResponseClient())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("The client cannot be set when the constructor used is \"OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); } @Test @@ -175,21 +207,22 @@ public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenRes SecurityContextHolder.setContext(securityContext); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient1); + methodParameter, null, new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); } @Test public void resolveArgumentWhenAuthorizedClientFoundThenResolves() throws Exception { MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient1); + methodParameter, null, new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); } @Test - public void resolveArgumentWhenRegistrationIdInvalidThenDoesNotResolve() throws Exception { + public void resolveArgumentWhenRegistrationIdInvalidThenThrowIllegalArgumentException() { MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid", OAuth2AuthorizedClient.class); - assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null)).isNull(); + assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id 'invalid'"); } @Test @@ -197,7 +230,7 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForAuthorizationCodeClien when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) .thenReturn(null); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null)) + assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null)) .isInstanceOf(ClientAuthorizationRequiredException.class); } @@ -206,7 +239,14 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForAuthorizationCodeClien public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClientThenResolvesFromTokenResponseClient() throws Exception { OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); - this.argumentResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient); + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = + new ClientCredentialsOAuth2AuthorizedClientProvider(); + clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(clientCredentialsAuthorizedClientProvider); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse .withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER) @@ -219,7 +259,7 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", OAuth2AuthorizedClient.class); OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null); + methodParameter, null, new ServletWebRequest(this.request, this.response), null); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration2); @@ -227,7 +267,7 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken()); verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null)); + eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), any(HttpServletResponse.class)); } private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index d99de2db281..86df1cec94f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.client.web.reactive.function.client; import org.junit.After; @@ -28,6 +27,9 @@ import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.FormHttpMessageWriter; import org.springframework.http.codec.HttpMessageWriter; @@ -45,24 +47,31 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.web.client.RestOperations; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Mono; import java.net.URI; import java.time.Duration; @@ -76,6 +85,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; @@ -95,6 +105,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; @Mock + private OAuth2AccessTokenResponseClient refreshTokenTokenResponseClient; + @Mock private WebClient.RequestHeadersSpec spec; @Captor private ArgumentCaptor>> attrs; @@ -106,14 +118,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { */ private Map result = new HashMap<>(); - private ServletOAuth2AuthorizedClientExchangeFilterFunction function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(); + private ServletOAuth2AuthorizedClientExchangeFilterFunction function; private MockExchangeFunction exchange = new MockExchangeFunction(); private Authentication authentication; - private ClientRegistration registration = TestClientRegistrations.clientRegistration() - .build(); + private ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token-0", @@ -123,6 +134,16 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Before public void setup() { this.authentication = new TestingAuthenticationToken("test", "this"); + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); } @After @@ -131,6 +152,35 @@ public void cleanup() { RequestContextHolder.resetRequestAttributes(); } + @Test + public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientCredentialsTokenResponseClient cannot be null"); + } + + @Test + public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { + assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(new DefaultClientCredentialsTokenResponseClient())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("The client cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + } + + @Test + public void setAccessTokenExpiresSkewWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { + assertThatThrownBy(() -> this.function.setAccessTokenExpiresSkew(Duration.ofSeconds(30))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("The accessTokenExpiresSkew cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + } + @Test public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() { Map attrs = getDefaultRequestAttributes(); @@ -156,8 +206,6 @@ public void defaultRequestAuthenticationWhenSecurityContextEmptyThenAuthenticati @Test public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); SecurityContextHolder.getContext().setAuthentication(this.authentication); Map attrs = getDefaultRequestAttributes(); assertThat(getAuthentication(attrs)).isEqualTo(this.authentication); @@ -166,8 +214,6 @@ public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationS @Test public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); oauth2AuthorizedClient(authorizedClient).accept(this.result); @@ -178,8 +224,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAnd @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); Map attrs = getDefaultRequestAttributes(); assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); verifyZeroInteractions(this.authorizedClientRepository); @@ -187,8 +231,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientR @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); Map attrs = getDefaultRequestAttributes(); assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); verifyZeroInteractions(this.authorizedClientRepository); @@ -208,8 +250,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenRepositoryNullThenOAuth2Auth @Test public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); this.function.setDefaultOAuth2AuthorizedClient(true); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); @@ -217,7 +257,10 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthentication OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); authentication(token).accept(this.result); + httpServletRequest(new MockHttpServletRequest()).accept(this.result); + httpServletResponse(new MockHttpServletResponse()).accept(this.result); Map attrs = getDefaultRequestAttributes(); @@ -227,8 +270,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthentication @Test public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); @@ -241,16 +282,17 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticatio @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); authentication(token).accept(this.result); clientRegistrationId("explicit").accept(this.result); + httpServletRequest(new MockHttpServletRequest()).accept(this.result); + httpServletResponse(new MockHttpServletResponse()).accept(this.result); Map attrs = getDefaultRequestAttributes(); @@ -260,12 +302,13 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegis @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); clientRegistrationId("id").accept(this.result); + httpServletRequest(new MockHttpServletRequest()).accept(this.result); + httpServletResponse(new MockHttpServletResponse()).accept(this.result); Map attrs = getDefaultRequestAttributes(); @@ -276,54 +319,53 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientR @Test public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( - accessTokenResponse); + when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + SecurityContextHolder.getContext().setAuthentication(this.authentication); + Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); + assertThat(authorizedClient.getPrincipalName()).isEqualTo("test"); assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); } @Test public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() { this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); this.function.setDefaultClientRegistrationId(this.registration.getRegistrationId()); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( - accessTokenResponse); + when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + SecurityContextHolder.getContext().setAuthentication(this.authentication); Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); + assertThat(authorizedClient.getPrincipalName()).isEqualTo("test"); assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); } @Test public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() { this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); assertThatCode(() -> getDefaultRequestAttributes()) @@ -353,8 +395,11 @@ public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { public void filterWhenAuthorizedClientThenAuthorizationHeader() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -366,9 +411,12 @@ public void filterWhenAuthorizedClientThenAuthorizationHeader() { public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .header(HttpHeaders.AUTHORIZATION, "Existing") .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -384,47 +432,43 @@ public void filterWhenRefreshRequiredThenRefresh() { .expiresIn(3600) .refreshToken("refresh-1") .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(response); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); - verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); + verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); + verify(this.authorizedClientRepository).saveAuthorizedClient( + this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken()); List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); + assertThat(requests).hasSize(1); ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); } @Test @@ -434,62 +478,67 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh .expiresIn(3600) // .refreshToken(xxx) // No refreshToken in response .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + + RestOperations refreshTokenClient = mock(RestOperations.class); + when(refreshTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .thenReturn(new ResponseEntity(response, HttpStatus.OK)); + DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + refreshTokenTokenResponseClient.setRestOperations(refreshTokenClient); + + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + authorizedClientProvider.setAccessTokenResponseClient(refreshTokenTokenResponseClient); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); + verify(refreshTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); - assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(refreshToken); + assertThat(newAuthorizedClient.getRefreshToken().getTokenValue()).isEqualTo(refreshToken.getTokenValue()); List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); + assertThat(requests).hasSize(1); ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); } @Test public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { this.registration = TestClientRegistrations.clientCredentials().build(); - - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -519,27 +568,26 @@ public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); - verify(clientCredentialsTokenResponseClient).getTokenResponse(any()); + verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); @@ -558,54 +606,46 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() .expiresIn(3600) .refreshToken("refresh-1") .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(response); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - + this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange) - .block(); + this.function.filter(request, this.exchange).block(); + verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any(), any()); List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); + assertThat(requests).hasSize(1); ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); } @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -622,14 +662,14 @@ public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -647,8 +687,6 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() { // gh-6483 @Test public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, this.authorizedClientRepository); this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized this.function.setDefaultOAuth2AuthorizedClient(true); @@ -664,9 +702,12 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.registration, "principalName", this.accessToken); + when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()), eq(authentication), eq(servletRequest))).thenReturn(authorizedClient); + when(this.clientRegistrationRepository.findByRegistrationId(eq(authentication.getAuthorizedClientRegistrationId()))).thenReturn(this.registration); + // Default request attributes set final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com")) .attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build(); @@ -698,8 +739,6 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { @Test public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, this.authorizedClientRepository); // this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized this.function.setDefaultOAuth2AuthorizedClient(true); @@ -729,8 +768,6 @@ public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvail @Test public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, this.authorizedClientRepository); this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized this.function.destroy(); // Hooks.onLastOperator() released this.function.setDefaultOAuth2AuthorizedClient(true); diff --git a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java index b995ffb61db..636bc53fd6f 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,11 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction; import org.springframework.web.reactive.function.client.WebClient; @@ -31,11 +35,28 @@ public class WebClientConfig { @Bean - WebClient webClient(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository); + WebClient webClient(OAuth2AuthorizedClientManager authorizedClientManager) { + ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = + new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); oauth2.setDefaultOAuth2AuthorizedClient(true); return WebClient.builder() .apply(oauth2.oauth2Configuration()) .build(); } + + @Bean + OAuth2AuthorizedClientManager authorizedClientManager(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; + } }