Skip to content

Add OAuth2TokenResponseEnhancer to enhance the access token response #961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenResponseEnhancer;
import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientCredentialsAuthenticationConverter;
Expand Down Expand Up @@ -67,6 +68,7 @@ public final class OAuth2TokenEndpointConfigurer extends AbstractOAuth2Configure
private Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer = (authenticationProviders) -> {};
private AuthenticationSuccessHandler accessTokenResponseHandler;
private AuthenticationFailureHandler errorResponseHandler;
private OAuth2TokenResponseEnhancer oauth2TokenResponseEnhancer;

/**
* Restrict for internal use only.
Expand Down Expand Up @@ -156,6 +158,18 @@ public OAuth2TokenEndpointConfigurer errorResponseHandler(AuthenticationFailureH
return this;
}

/**
* Sets the {@link OAuth2TokenResponseEnhancer} used for enhance {@link OAuth2AccessTokenResponse#additionalParameters}.
*
* @param oauth2TokenResponseEnhancer the {@link OAuth2TokenResponseEnhancer} used for
* enhance additional parameters for OAuth2 Token Response.
* @return the {@link OAuth2TokenEndpointConfigurer} for further configuration
*/
public OAuth2TokenEndpointConfigurer oauth2TokenResponseEnhancer(OAuth2TokenResponseEnhancer oauth2TokenResponseEnhancer) {
this.oauth2TokenResponseEnhancer = oauth2TokenResponseEnhancer;
return this;
}

@Override
void init(HttpSecurity httpSecurity) {
AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity);
Expand Down Expand Up @@ -193,6 +207,9 @@ void configure(HttpSecurity httpSecurity) {
if (this.errorResponseHandler != null) {
tokenEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler);
}
if (this.oauth2TokenResponseEnhancer != null) {
tokenEndpointFilter.setOAuth2TokenResponseEnhancer(this.oauth2TokenResponseEnhancer);
}
httpSecurity.addFilterAfter(postProcess(tokenEndpointFilter), FilterSecurityInterceptor.class);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
private AuthenticationConverter authenticationConverter;
private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse;
private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
private OAuth2TokenResponseEnhancer oauth2TokenResponseEnhancer = this::getAdditionalParameters;

/**
* Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
Expand Down Expand Up @@ -214,6 +215,18 @@ public void setAuthenticationFailureHandler(AuthenticationFailureHandler authent
this.authenticationFailureHandler = authenticationFailureHandler;
}

/**
* Sets the {@link OAuth2TokenResponseEnhancer} used for enhance {@link OAuth2AccessTokenResponse#additionalParameters}.
*
* @param oauth2TokenResponseEnhancer the {@link OAuth2TokenResponseEnhancer} used for
* enhance additional parameters for OAuth2 Token Response.
*
*/
public void setOAuth2TokenResponseEnhancer(OAuth2TokenResponseEnhancer oauth2TokenResponseEnhancer) {
Assert.notNull(oauth2TokenResponseEnhancer, "oauth2TokenResponseEnhancer cannot be null");
this.oauth2TokenResponseEnhancer = oauth2TokenResponseEnhancer;
}

private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws IOException {

Expand All @@ -222,7 +235,6 @@ private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResp

OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
Map<String, Object> additionalParameters = accessTokenAuthentication.getAdditionalParameters();

OAuth2AccessTokenResponse.Builder builder =
OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
Expand All @@ -234,6 +246,7 @@ private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResp
if (refreshToken != null) {
builder.refreshToken(refreshToken.getTokenValue());
}
Map<String, Object> additionalParameters = oauth2TokenResponseEnhancer.enhance(accessTokenAuthentication);
if (!CollectionUtils.isEmpty(additionalParameters)) {
builder.additionalParameters(additionalParameters);
}
Expand All @@ -251,6 +264,10 @@ private void sendErrorResponse(HttpServletRequest request, HttpServletResponse r
this.errorHttpResponseConverter.write(error, null, httpResponse);
}

private Map<String, Object> getAdditionalParameters(OAuth2AccessTokenAuthenticationToken accessTokenAuthentication) {
return accessTokenAuthentication.getAdditionalParameters();
}

private static void throwError(String errorCode, String parameterName) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, DEFAULT_ERROR_URI);
throw new OAuth2AuthenticationException(error);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2020-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.web;

import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;

import java.util.Map;

/**
* Customize additional parameters for OAuth2 Token Response
*
* @see org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
* @author hccake
* @since 0.4.0
*/
@FunctionalInterface
public interface OAuth2TokenResponseEnhancer {

/**
* Provide an additional parameter map to enhance OAuth2 Token Response
*
* @param accessTokenAuthentication An {@link Authentication} implementation used when issuing an
* * OAuth 2.0 Access Token and (optional) Refresh Token.
* @return an additional parameter map
*/
Map<String, Object> enhance(OAuth2AccessTokenAuthenticationToken accessTokenAuthentication);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

Expand Down Expand Up @@ -145,6 +146,13 @@ public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentExcep
.hasMessage("authenticationFailureHandler cannot be null");
}

@Test
public void setOAuth2TokenResponseEnhancerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.filter.setOAuth2TokenResponseEnhancer(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("oauth2TokenResponseEnhancer cannot be null");
}

@Test
public void doFilterWhenNotTokenRequestThenNotProcessed() throws Exception {
String requestUri = "/path";
Expand Down Expand Up @@ -571,6 +579,82 @@ public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exce
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
}

@Test
public void doFilterWhenAuthorizationCodeTokenRequestThenAccessTokenResponseWithCustomTokenResponseEnhancer() throws Exception {
this.filter.setOAuth2TokenResponseEnhancer(accessTokenAuthentication -> {
Map<String, Object> enhanceParameters = new HashMap<>();
if (accessTokenAuthentication.getAccessToken().getScopes().contains("scope1")) {
enhanceParameters.put("some-info-key", "some-info-value");
}
if (accessTokenAuthentication.getAdditionalParameters() != null) {
enhanceParameters.putAll(accessTokenAuthentication.getAdditionalParameters());
}
return enhanceParameters;
});

RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(
registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "token",
Instant.now(), Instant.now().plus(Duration.ofHours(1)),
new HashSet<>(Arrays.asList("scope1", "scope2")));
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
"refresh-token", Instant.now(), Instant.now().plus(Duration.ofDays(1)));
Map<String, Object> additionalParameters = Collections.singletonMap("custom-param", "custom-value");
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
new OAuth2AccessTokenAuthenticationToken(
registeredClient, clientPrincipal, accessToken, refreshToken, additionalParameters);

when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication);

SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(clientPrincipal);
SecurityContextHolder.setContext(securityContext);

MockHttpServletRequest request = createAuthorizationCodeTokenRequest(registeredClient);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);

this.filter.doFilter(request, response, filterChain);

verifyNoInteractions(filterChain);

ArgumentCaptor<OAuth2AuthorizationCodeAuthenticationToken> authorizationCodeAuthenticationCaptor =
ArgumentCaptor.forClass(OAuth2AuthorizationCodeAuthenticationToken.class);
verify(this.authenticationManager).authenticate(authorizationCodeAuthenticationCaptor.capture());

OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
authorizationCodeAuthenticationCaptor.getValue();
assertThat(authorizationCodeAuthentication.getCode()).isEqualTo(
request.getParameter(OAuth2ParameterNames.CODE));
assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo(
request.getParameter(OAuth2ParameterNames.REDIRECT_URI));
assertThat(authorizationCodeAuthentication.getAdditionalParameters())
.containsExactly(entry("custom-param-1", "custom-value-1"));
assertThat(authorizationCodeAuthentication.getDetails())
.asInstanceOf(type(WebAuthenticationDetails.class))
.extracting(WebAuthenticationDetails::getRemoteAddress)
.isEqualTo(REMOTE_ADDRESS);

assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);

OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken();
assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType());
assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue());
assertThat(accessTokenResult.getIssuedAt()).isBetween(
accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1));
assertThat(accessTokenResult.getExpiresAt()).isBetween(
accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1));
assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes());
assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(refreshToken.getTokenValue());
assertThat(accessTokenResponse.getAdditionalParameters())
.containsOnly(entry("custom-param", "custom-value"), entry("some-info-key", "some-info-value"));
}


private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode,
MockHttpServletRequest request) throws Exception {

Expand Down