diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenEndpointConfigurer.java index 5a8fe94f5..0f0fbd204 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenEndpointConfigurer.java @@ -39,6 +39,7 @@ import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; +import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenResponseEnhancer; import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientCredentialsAuthenticationConverter; @@ -67,6 +68,7 @@ public final class OAuth2TokenEndpointConfigurer extends AbstractOAuth2Configure private Consumer> authenticationProvidersConsumer = (authenticationProviders) -> {}; private AuthenticationSuccessHandler accessTokenResponseHandler; private AuthenticationFailureHandler errorResponseHandler; + private OAuth2TokenResponseEnhancer oauth2TokenResponseEnhancer; /** * Restrict for internal use only. @@ -156,6 +158,18 @@ public OAuth2TokenEndpointConfigurer errorResponseHandler(AuthenticationFailureH return this; } + /** + * Sets the {@link OAuth2TokenResponseEnhancer} used for enhance {@link OAuth2AccessTokenResponse#additionalParameters}. + * + * @param oauth2TokenResponseEnhancer the {@link OAuth2TokenResponseEnhancer} used for + * enhance additional parameters for OAuth2 Token Response. + * @return the {@link OAuth2TokenEndpointConfigurer} for further configuration + */ + public OAuth2TokenEndpointConfigurer oauth2TokenResponseEnhancer(OAuth2TokenResponseEnhancer oauth2TokenResponseEnhancer) { + this.oauth2TokenResponseEnhancer = oauth2TokenResponseEnhancer; + return this; + } + @Override void init(HttpSecurity httpSecurity) { AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); @@ -193,6 +207,9 @@ void configure(HttpSecurity httpSecurity) { if (this.errorResponseHandler != null) { tokenEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler); } + if (this.oauth2TokenResponseEnhancer != null) { + tokenEndpointFilter.setOAuth2TokenResponseEnhancer(this.oauth2TokenResponseEnhancer); + } httpSecurity.addFilterAfter(postProcess(tokenEndpointFilter), FilterSecurityInterceptor.class); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 865b239a1..ffd2f08bc 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -110,6 +110,7 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter { private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse; private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + private OAuth2TokenResponseEnhancer oauth2TokenResponseEnhancer = this::getAdditionalParameters; /** * Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters. @@ -214,6 +215,18 @@ public void setAuthenticationFailureHandler(AuthenticationFailureHandler authent this.authenticationFailureHandler = authenticationFailureHandler; } + /** + * Sets the {@link OAuth2TokenResponseEnhancer} used for enhance {@link OAuth2AccessTokenResponse#additionalParameters}. + * + * @param oauth2TokenResponseEnhancer the {@link OAuth2TokenResponseEnhancer} used for + * enhance additional parameters for OAuth2 Token Response. + * + */ + public void setOAuth2TokenResponseEnhancer(OAuth2TokenResponseEnhancer oauth2TokenResponseEnhancer) { + Assert.notNull(oauth2TokenResponseEnhancer, "oauth2TokenResponseEnhancer cannot be null"); + this.oauth2TokenResponseEnhancer = oauth2TokenResponseEnhancer; + } + private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException { @@ -222,7 +235,6 @@ private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResp OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken(); OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken(); - Map additionalParameters = accessTokenAuthentication.getAdditionalParameters(); OAuth2AccessTokenResponse.Builder builder = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) @@ -234,6 +246,7 @@ private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResp if (refreshToken != null) { builder.refreshToken(refreshToken.getTokenValue()); } + Map additionalParameters = oauth2TokenResponseEnhancer.enhance(accessTokenAuthentication); if (!CollectionUtils.isEmpty(additionalParameters)) { builder.additionalParameters(additionalParameters); } @@ -251,6 +264,10 @@ private void sendErrorResponse(HttpServletRequest request, HttpServletResponse r this.errorHttpResponseConverter.write(error, null, httpResponse); } + private Map getAdditionalParameters(OAuth2AccessTokenAuthenticationToken accessTokenAuthentication) { + return accessTokenAuthentication.getAdditionalParameters(); + } + private static void throwError(String errorCode, String parameterName) { OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, DEFAULT_ERROR_URI); throw new OAuth2AuthenticationException(error); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenResponseEnhancer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenResponseEnhancer.java new file mode 100644 index 000000000..8205f8f4b --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenResponseEnhancer.java @@ -0,0 +1,41 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; + +import java.util.Map; + +/** + * Customize additional parameters for OAuth2 Token Response + * + * @see org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse + * @author hccake + * @since 0.4.0 + */ +@FunctionalInterface +public interface OAuth2TokenResponseEnhancer { + + /** + * Provide an additional parameter map to enhance OAuth2 Token Response + * + * @param accessTokenAuthentication An {@link Authentication} implementation used when issuing an + * * OAuth 2.0 Access Token and (optional) Refresh Token. + * @return an additional parameter map + */ + Map enhance(OAuth2AccessTokenAuthenticationToken accessTokenAuthentication); +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index b252ce5c1..d613c8e92 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -19,6 +19,7 @@ import java.time.Instant; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Map; @@ -145,6 +146,13 @@ public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentExcep .hasMessage("authenticationFailureHandler cannot be null"); } + @Test + public void setOAuth2TokenResponseEnhancerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setOAuth2TokenResponseEnhancer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("oauth2TokenResponseEnhancer cannot be null"); + } + @Test public void doFilterWhenNotTokenRequestThenNotProcessed() throws Exception { String requestUri = "/path"; @@ -571,6 +579,82 @@ public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exce verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any()); } + @Test + public void doFilterWhenAuthorizationCodeTokenRequestThenAccessTokenResponseWithCustomTokenResponseEnhancer() throws Exception { + this.filter.setOAuth2TokenResponseEnhancer(accessTokenAuthentication -> { + Map enhanceParameters = new HashMap<>(); + if (accessTokenAuthentication.getAccessToken().getScopes().contains("scope1")) { + enhanceParameters.put("some-info-key", "some-info-value"); + } + if (accessTokenAuthentication.getAdditionalParameters() != null) { + enhanceParameters.putAll(accessTokenAuthentication.getAdditionalParameters()); + } + return enhanceParameters; + }); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plus(Duration.ofHours(1)), + new HashSet<>(Arrays.asList("scope1", "scope2"))); + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken( + "refresh-token", Instant.now(), Instant.now().plus(Duration.ofDays(1))); + Map additionalParameters = Collections.singletonMap("custom-param", "custom-value"); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + new OAuth2AccessTokenAuthenticationToken( + registeredClient, clientPrincipal, accessToken, refreshToken, additionalParameters); + + when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + + MockHttpServletRequest request = createAuthorizationCodeTokenRequest(registeredClient); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + ArgumentCaptor authorizationCodeAuthenticationCaptor = + ArgumentCaptor.forClass(OAuth2AuthorizationCodeAuthenticationToken.class); + verify(this.authenticationManager).authenticate(authorizationCodeAuthenticationCaptor.capture()); + + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = + authorizationCodeAuthenticationCaptor.getValue(); + assertThat(authorizationCodeAuthentication.getCode()).isEqualTo( + request.getParameter(OAuth2ParameterNames.CODE)); + assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo( + request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); + assertThat(authorizationCodeAuthentication.getAdditionalParameters()) + .containsExactly(entry("custom-param-1", "custom-value-1")); + assertThat(authorizationCodeAuthentication.getDetails()) + .asInstanceOf(type(WebAuthenticationDetails.class)) + .extracting(WebAuthenticationDetails::getRemoteAddress) + .isEqualTo(REMOTE_ADDRESS); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); + + OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken(); + assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType()); + assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue()); + assertThat(accessTokenResult.getIssuedAt()).isBetween( + accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1)); + assertThat(accessTokenResult.getExpiresAt()).isBetween( + accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1)); + assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); + assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(refreshToken.getTokenValue()); + assertThat(accessTokenResponse.getAdditionalParameters()) + .containsOnly(entry("custom-param", "custom-value"), entry("some-info-key", "some-info-value")); + } + + private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode, MockHttpServletRequest request) throws Exception {