diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsentContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsentContext.java new file mode 100644 index 000000000..e5586e102 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsentContext.java @@ -0,0 +1,212 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.context.Context; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +/** + * A context that holds an {@link OAuth2AuthorizationConsent.Builder} and (optionally) additional information + * and is used when customizing the building of {@link OAuth2AuthorizationConsent}. + * + * @author Steve Riesenberg + * @since 0.2.1 + * @see Context + */ +public final class OAuth2AuthorizationConsentContext implements Context { + private final Map context; + + /** + * Constructs an {@code OAuth2AuthorizationConsentContext} using the provided parameters. + * + * @param context a {@code Map} of additional context information + */ + private OAuth2AuthorizationConsentContext(@Nullable Map context) { + this.context = new HashMap<>(); + if (!CollectionUtils.isEmpty(context)) { + this.context.putAll(context); + } + } + + /** + * Returns the {@link OAuth2AuthorizationConsent.Builder authorization consent builder}. + * + * @return the {@link OAuth2AuthorizationConsent.Builder} + */ + public OAuth2AuthorizationConsent.Builder getAuthorizationConsentBuilder() { + return get(OAuth2AuthorizationConsent.Builder.class); + } + + /** + * Returns the {@link Authentication} representing the {@code Principal} resource owner (or client). + * + * @param the type of the {@code Authentication} + * @return the {@link Authentication} representing the {@code Principal} resource owner (or client) + */ + @Nullable + public T getPrincipal() { + return get(Builder.PRINCIPAL_AUTHENTICATION_KEY); + } + + /** + * Returns the {@link RegisteredClient registered client}. + * + * @return the {@link RegisteredClient}, or {@code null} if not available + */ + @Nullable + public RegisteredClient getRegisteredClient() { + return get(RegisteredClient.class); + } + + /** + * Returns the {@link OAuth2Authorization authorization}. + * + * @return the {@link OAuth2Authorization}, or {@code null} if not available + */ + @Nullable + public OAuth2Authorization getAuthorization() { + return get(OAuth2Authorization.class); + } + + /** + * Returns the {@link OAuth2AuthorizationRequest authorization request}. + * + * @return the {@link OAuth2AuthorizationRequest}, or {@code null} if not available + */ + @Nullable + public OAuth2AuthorizationRequest getAuthorizationRequest() { + return get(OAuth2AuthorizationRequest.class); + } + + @SuppressWarnings("unchecked") + @Override + public V get(Object key) { + return (V) this.context.get(key); + } + + @Override + public boolean hasKey(Object key) { + return this.context.containsKey(key); + } + + /** + * Constructs a new {@link Builder} with the provided {@link OAuth2AuthorizationConsent.Builder}. + * + * @param authorizationConsentBuilder the {@link OAuth2AuthorizationConsent.Builder} to initialize the builder + * @return the {@link Builder} + */ + public static OAuth2AuthorizationConsentContext.Builder with(OAuth2AuthorizationConsent.Builder authorizationConsentBuilder) { + return new Builder(authorizationConsentBuilder); + } + + /** + * A builder for {@link OAuth2AuthorizationConsentContext}. + */ + public static final class Builder { + private static final String PRINCIPAL_AUTHENTICATION_KEY = + Authentication.class.getName().concat(".PRINCIPAL"); + private final Map context = new HashMap<>(); + + private Builder(OAuth2AuthorizationConsent.Builder authorizationConsentBuilder) { + Assert.notNull(authorizationConsentBuilder, "authorizationConsentBuilder cannot be null"); + put(OAuth2AuthorizationConsent.Builder.class, authorizationConsentBuilder); + } + + /** + * Sets the {@link Authentication} representing the {@code Principal} resource owner (or client). + * + * @param principal the {@link Authentication} representing the {@code Principal} resource owner (or client) + * @return the {@link Builder} for further configuration + */ + public Builder principal(Authentication principal) { + return put(PRINCIPAL_AUTHENTICATION_KEY, principal); + } + + /** + * Sets the {@link RegisteredClient registered client}. + * + * @param registeredClient the {@link RegisteredClient} + * @return the {@link Builder} for further configuration + */ + public Builder registeredClient(RegisteredClient registeredClient) { + return put(RegisteredClient.class, registeredClient); + } + + /** + * Sets the {@link OAuth2Authorization authorization}. + * + * @param authorization the {@link OAuth2Authorization} + * @return the {@link Builder} for further configuration + */ + public Builder authorization(OAuth2Authorization authorization) { + return put(OAuth2Authorization.class, authorization); + } + + /** + * Sets the {@link OAuth2AuthorizationRequest authorization request}. + * + * @param authorizationRequest the {@link OAuth2AuthorizationRequest} + * @return the {@link Builder} for further configuration + */ + public Builder authorizationRequest(OAuth2AuthorizationRequest authorizationRequest) { + return put(OAuth2AuthorizationRequest.class, authorizationRequest); + } + + /** + * Associates an attribute. + * + * @param key the key for the attribute + * @param value the value of the attribute + * @return the {@link OAuth2TokenContext.AbstractBuilder} for further configuration + */ + public Builder put(Object key, Object value) { + Assert.notNull(key, "key cannot be null"); + Assert.notNull(value, "value cannot be null"); + this.context.put(key, value); + return this; + } + + /** + * A {@code Consumer} of the attributes {@code Map} + * allowing the ability to add, replace, or remove. + * + * @param contextConsumer a {@link Consumer} of the attributes {@code Map} + * @return the {@link Builder} for further configuration + */ + public Builder context(Consumer> contextConsumer) { + contextConsumer.accept(this.context); + return this; + } + + /** + * Builds a new {@link OAuth2AuthorizationConsentContext}. + * + * @return the {@link OAuth2AuthorizationConsentContext} + */ + public OAuth2AuthorizationConsentContext build() { + return new OAuth2AuthorizationConsentContext(this.context); + } + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java index e1bcacda6..b2b325539 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java @@ -29,6 +29,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.config.Customizer; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; @@ -46,6 +47,7 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentContext; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; @@ -82,6 +84,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen private final OAuth2AuthorizationConsentService authorizationConsentService; private Supplier authorizationCodeGenerator = DEFAULT_AUTHORIZATION_CODE_GENERATOR::generateKey; private Function authenticationValidatorResolver = DEFAULT_AUTHENTICATION_VALIDATOR_RESOLVER; + private Customizer authorizationConsentCustomizer; /** * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationProvider} using the provided parameters. @@ -145,6 +148,30 @@ public void setAuthenticationValidatorResolver(Function + * The following context attributes are available: + *
    + *
  • The {@link OAuth2AuthorizationConsent.Builder} used to build the authorization consent + * prior to {@link OAuth2AuthorizationConsentService#save(OAuth2AuthorizationConsent)}
  • + *
  • The {@link Authentication authentication principal} of type + * {@link OAuth2AuthorizationCodeRequestAuthenticationToken}
  • + *
  • The {@link OAuth2Authorization} associated with the state token presented in the + * authorization consent request.
  • + *
  • The {@link OAuth2AuthorizationRequest} requiring the resource owner's consent.
  • + *
+ * + * @param authorizationConsentCustomizer the {@link Customizer} providing access to the + * {@link OAuth2AuthorizationConsentContext} containing an {@link OAuth2AuthorizationConsent.Builder} + */ + public void setAuthorizationConsentCustomizer(Customizer authorizationConsentCustomizer) { + Assert.notNull(authorizationConsentCustomizer, "authorizationConsentCustomizer cannot be null"); + this.authorizationConsentCustomizer = authorizationConsentCustomizer; + } + private Authentication authenticateAuthorizationRequest(Authentication authentication) throws AuthenticationException { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = (OAuth2AuthorizationCodeRequestAuthenticationToken) authentication; @@ -301,7 +328,8 @@ private Authentication authenticateAuthorizationConsent(Authentication authentic Set currentAuthorizedScopes = currentAuthorizationConsent != null ? currentAuthorizationConsent.getScopes() : Collections.emptySet(); - if (authorizedScopes.isEmpty() && currentAuthorizedScopes.isEmpty()) { + if (authorizedScopes.isEmpty() && currentAuthorizedScopes.isEmpty() + && authorizationCodeRequestAuthentication.getAdditionalParameters().isEmpty()) { // Authorization consent denied this.authorizationService.remove(authorization); throwError(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID, @@ -321,16 +349,30 @@ private Authentication authenticateAuthorizationConsent(Authentication authentic } } - if (!authorizedScopes.isEmpty() && !authorizedScopes.equals(currentAuthorizedScopes)) { - OAuth2AuthorizationConsent.Builder authorizationConsentBuilder; - if (currentAuthorizationConsent != null) { - authorizationConsentBuilder = OAuth2AuthorizationConsent.from(currentAuthorizationConsent); - } else { - authorizationConsentBuilder = OAuth2AuthorizationConsent.withId( - authorization.getRegisteredClientId(), authorization.getPrincipalName()); - } - authorizedScopes.forEach(authorizationConsentBuilder::scope); - OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build(); + OAuth2AuthorizationConsent.Builder authorizationConsentBuilder; + if (currentAuthorizationConsent != null) { + authorizationConsentBuilder = OAuth2AuthorizationConsent.from(currentAuthorizationConsent); + } else { + authorizationConsentBuilder = OAuth2AuthorizationConsent.withId( + authorization.getRegisteredClientId(), authorization.getPrincipalName()); + } + authorizedScopes.forEach(authorizationConsentBuilder::scope); + + if (this.authorizationConsentCustomizer != null) { + // @formatter:off + OAuth2AuthorizationConsentContext authorizationConsentContext = + OAuth2AuthorizationConsentContext.with(authorizationConsentBuilder) + .principal(authorizationCodeRequestAuthentication) + .registeredClient(registeredClient) + .authorization(authorization) + .authorizationRequest(authorizationRequest) + .build(); + // @formatter:on + this.authorizationConsentCustomizer.customize(authorizationConsentContext); + } + + OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build(); + if (!authorizationConsent.equals(currentAuthorizationConsent)) { this.authorizationConsentService.save(authorizationConsent); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index 7c4b3e8cb..6a80ee09c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -23,14 +23,17 @@ import java.text.MessageFormat; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; import java.util.Base64; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; +import org.assertj.core.matcher.AssertionMatcher; import org.junit.After; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -52,12 +55,14 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.crypto.password.NoOpPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -77,10 +82,13 @@ import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentContext; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper; @@ -495,6 +503,60 @@ public void requestWhenCustomConsentPageConfiguredThenRedirect() throws Exceptio assertThat(authorization).isNotNull(); } + @Test + public void requestWhenCustomConsentCustomizerConfiguredThenUsed() throws Exception { + this.spring.register(AuthorizationServerConfigurationCustomConsentRequest.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(ClientSettings.builder() + .requireAuthorizationConsent(true) + .setting("custom.allowed-authorities", "authority-1 authority-2") + .build()) + .build(); + this.registeredClientRepository.save(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .build(); + this.authorizationService.save(authorization); + + MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI) + .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) + .param("authority", "authority-1 authority-2") + .param(OAuth2ParameterNames.STATE, "state") + .with(user("principal"))) + .andExpect(status().is3xxRedirection()) + .andReturn(); + + String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + + String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); + OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); + + mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI) + .params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization)) + .header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))) + .andExpect(status().isOk()) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store"))) + .andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache"))) + .andExpect(jsonPath("$.access_token").isNotEmpty()) + .andExpect(jsonPath("$.access_token").value(new AssertionMatcher() { + @Override + public void assertion(String accessToken) throws AssertionError { + Jwt jwt = jwtDecoder.decode(accessToken); + assertThat(jwt.getClaimAsStringList(AUTHORITIES_CLAIM)) + .containsExactlyInAnyOrder("authority-1", "authority-2"); + } + })) + .andExpect(jsonPath("$.token_type").isNotEmpty()) + .andExpect(jsonPath("$.expires_in").isNotEmpty()) + .andExpect(jsonPath("$.refresh_token").isNotEmpty()) + .andExpect(jsonPath("$.scope").doesNotExist()) + .andReturn(); + + String json = mvcResult.getResponse().getContentAsString(); + } + @Test public void requestWhenAuthorizationEndpointCustomizedThenUsed() throws Exception { this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire(); @@ -693,6 +755,100 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h // @formatter:on } + @EnableWebSecurity + static class AuthorizationServerConfigurationCustomConsentRequest extends AuthorizationServerConfiguration { + @Autowired + private RegisteredClientRepository registeredClientRepository; + + @Autowired + private OAuth2AuthorizationService authorizationService; + + @Autowired + private OAuth2AuthorizationConsentService authorizationConsentService; + + // @formatter:off + @Bean + public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + OAuth2AuthorizationServerConfigurer authorizationServerConfigurer = + new OAuth2AuthorizationServerConfigurer<>(); + authorizationServerConfigurer + .authorizationEndpoint(authorizationEndpoint -> + authorizationEndpoint.authenticationProvider(createProvider())); + RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher(); + + http + .requestMatcher(endpointsMatcher) + .authorizeRequests(authorizeRequests -> + authorizeRequests.anyRequest().authenticated() + ) + .csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher)) + .apply(authorizationServerConfigurer); + return http.build(); + } + // @formatter:on + + @Bean + @Override + OAuth2TokenCustomizer jwtCustomizer() { + return context -> { + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) && + OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) { + OAuth2AuthorizationConsent authorizationConsent = authorizationConsentService.findById( + context.getRegisteredClient().getId(), context.getPrincipal().getName()); + + Set authorities = new HashSet<>(); + for (GrantedAuthority authority : authorizationConsent.getAuthorities()) { + authorities.add(authority.getAuthority()); + } + context.getClaims().claim(AUTHORITIES_CLAIM, authorities); + } + }; + } + + private AuthenticationProvider createProvider() { + OAuth2AuthorizationCodeRequestAuthenticationProvider authorizationCodeRequestAuthenticationProvider = + new OAuth2AuthorizationCodeRequestAuthenticationProvider( + this.registeredClientRepository, + this.authorizationService, + this.authorizationConsentService); + authorizationCodeRequestAuthenticationProvider.setAuthorizationConsentCustomizer(new ConsentCustomizer()); + + return authorizationCodeRequestAuthenticationProvider; + } + + static class ConsentCustomizer implements Customizer { + @Override + public void customize(OAuth2AuthorizationConsentContext authorizationConsentContext) { + OAuth2AuthorizationConsent.Builder authorizationConsentBuilder = + authorizationConsentContext.getAuthorizationConsentBuilder(); + OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = + authorizationConsentContext.getPrincipal(); + Map additionalParameters = + authorizationCodeRequestAuthentication.getAdditionalParameters(); + RegisteredClient registeredClient = authorizationConsentContext.getRegisteredClient(); + ClientSettings clientSettings = registeredClient.getClientSettings(); + + Set requestedAuthorities = authorities((String) additionalParameters.get("authority")); + Set allowedAuthorities = authorities(clientSettings.getSetting("custom.allowed-authorities")); + for (String requestedAuthority : requestedAuthorities) { + if (allowedAuthorities.contains(requestedAuthority)) { + authorizationConsentBuilder.authority(new SimpleGrantedAuthority(requestedAuthority)); + } + } + } + + private static Set authorities(String param) { + Set authorities = new HashSet<>(); + if (param != null) { + List authorityValues = Arrays.asList(param.split(" ")); + authorities.addAll(authorityValues); + } + + return authorities; + } + } + } + @EnableWebSecurity static class AuthorizationServerConfigurationCustomAuthorizationEndpoint extends AuthorizationServerConfiguration { // @formatter:off diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsentContextTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsentContextTests.java new file mode 100644 index 000000000..d964e96ef --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsentContextTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.junit.Test; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2AuthorizationConsentContext}. + * + * @author Steve Riesenberg + */ +public class OAuth2AuthorizationConsentContextTests { + + @Test + public void withWhenAuthorizationConsentBuilderNullThenIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationConsentContext.with(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationConsentBuilder cannot be null"); + } + + @Test + public void setWhenValueNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationConsentContext.Builder builder = OAuth2AuthorizationConsentContext + .with(OAuth2AuthorizationConsent.withId("some-client", "some-principal")); + assertThatThrownBy(() -> builder.principal(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.registeredClient(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.authorization(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.authorizationRequest(null)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> builder.put(null, "")) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizationConsent.Builder authorizationConsentBuilder = OAuth2AuthorizationConsent + .withId("some-client", "some-principal"); + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "password"); + OAuth2AuthorizationCodeRequestAuthenticationToken authentication = + OAuth2AuthorizationCodeRequestAuthenticationToken.with("test-client", principal) + .authorizationUri("https://provider.com/oauth2/authorize") + .build(); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build(); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationRequest.class.getName()); + + OAuth2AuthorizationConsentContext context = OAuth2AuthorizationConsentContext + .with(authorizationConsentBuilder) + .principal(authentication) + .registeredClient(registeredClient) + .authorization(authorization) + .authorizationRequest(authorizationRequest) + .put("custom-key-1", "custom-value-1") + .context(ctx -> ctx.put("custom-key-2", "custom-value-2")) + .build(); + + assertThat(context.getAuthorizationConsentBuilder()).isEqualTo(authorizationConsentBuilder); + assertThat(context.getPrincipal()).isEqualTo(authentication); + assertThat(context.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(context.getAuthorization()).isEqualTo(authorization); + assertThat(context.getAuthorizationRequest()).isEqualTo(authorizationRequest); + assertThat(context.get("custom-key-1")).isEqualTo("custom-value-1"); + assertThat(context.get("custom-key-2")).isEqualTo("custom-value-2"); + } +} \ No newline at end of file diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java index 4f625c578..0caac28f4 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java @@ -29,8 +29,10 @@ import org.mockito.ArgumentCaptor; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.Customizer; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationCode; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2TokenType; @@ -41,8 +43,8 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.core.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentContext; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; @@ -129,6 +131,13 @@ public void setAuthenticationValidatorResolverWhenNullThenThrowIllegalArgumentEx .hasMessage("authenticationValidatorResolver cannot be null"); } + @Test + public void setAuthorizationConsentCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setAuthorizationConsentCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationConsentCustomizer cannot be null"); + } + @Test public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); @@ -773,6 +782,53 @@ public void authenticateWhenConsentRequestApproveAllThenReturnAuthorizationCode( OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertAuthorizationConsentRequestWithAuthorizationCodeResult(registeredClient, authorization, authenticationResult); + } + + @Test + public void authenticateWhenCustomAuthorizationConsentCustomizerThenUsed() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(this.principal.getName()) + .build(); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); + Set authorizedScopes = authorizationRequest.getScopes(); + OAuth2AuthorizationCodeRequestAuthenticationToken authentication = + authorizationConsentRequestAuthentication(registeredClient, this.principal) + .scopes(authorizedScopes) // Approve all scopes + .build(); + when(this.authorizationService.findByToken(eq(authentication.getState()), eq(STATE_TOKEN_TYPE))) + .thenReturn(authorization); + + @SuppressWarnings("unchecked") + Customizer authorizationConsentCustomizer = mock(Customizer.class); + this.authenticationProvider.setAuthorizationConsentCustomizer(authorizationConsentCustomizer); + + OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult = + (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + assertAuthorizationConsentRequestWithAuthorizationCodeResult(registeredClient, authorization, authenticationResult); + + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationConsentContext.class); + verify(authorizationConsentCustomizer).customize(contextCaptor.capture()); + + OAuth2AuthorizationConsentContext context = contextCaptor.getValue(); + assertThat((Authentication) context.getPrincipal()).isEqualTo(authentication); + assertThat(context.get(OAuth2AuthorizationConsent.Builder.class)).isInstanceOf(OAuth2AuthorizationConsent.Builder.class); + assertThat(context.get(OAuth2Authorization.class)).isInstanceOf(OAuth2Authorization.class); + assertThat(context.get(OAuth2AuthorizationRequest.class)).isInstanceOf(OAuth2AuthorizationRequest.class); + } + + private void assertAuthorizationConsentRequestWithAuthorizationCodeResult( + RegisteredClient registeredClient, + OAuth2Authorization authorization, + OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult) { + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); + Set authorizedScopes = authorizationRequest.getScopes(); + ArgumentCaptor authorizationConsentCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationConsent.class); verify(this.authorizationConsentService).save(authorizationConsentCaptor.capture()); OAuth2AuthorizationConsent authorizationConsent = authorizationConsentCaptor.getValue();