Skip to content

Add authenticationDetailsSource to OAuth2TokenRevocationEndpointFilter #1667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
Expand All @@ -37,6 +39,7 @@
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -66,6 +69,8 @@ public final class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFil

private final RequestMatcher tokenRevocationEndpointMatcher;

private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();

private AuthenticationConverter authenticationConverter;

private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendRevocationSuccessResponse;
Expand Down Expand Up @@ -111,6 +116,12 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
Authentication tokenRevocationAuthentication = this.authenticationConverter.convert(request);
Authentication tokenRevocationAuthenticationResult = this.authenticationManager
.authenticate(tokenRevocationAuthentication);

if (tokenRevocationAuthenticationResult instanceof AbstractAuthenticationToken) {
((AbstractAuthenticationToken) tokenRevocationAuthenticationResult)
.setDetails(this.authenticationDetailsSource.buildDetails(request));
}

this.authenticationSuccessHandler.onAuthenticationSuccess(request, response,
tokenRevocationAuthenticationResult);
}
Expand All @@ -123,6 +134,18 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
}
}

/**
* Sets the {@link AuthenticationDetailsSource} used for building an authentication
* details instance from {@link HttpServletRequest}.
* @param authenticationDetailsSource the {@link AuthenticationDetailsSource} used for
* building an authentication details instance from {@link HttpServletRequest}
*/
public void setAuthenticationDetailsSource(
AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
Assert.notNull(authenticationDetailsSource, "authenticationDetailsSource cannot be null");
this.authenticationDetailsSource = authenticationDetailsSource;
}

/**
* Sets the {@link AuthenticationConverter} used when attempting to extract a Revoke
* Token Request from {@link HttpServletRequest} to an instance of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
Expand All @@ -45,13 +46,15 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetails;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -102,6 +105,13 @@ public void constructorWhenTokenRevocationEndpointUriNullThenThrowIllegalArgumen
.hasMessage("tokenRevocationEndpointUri cannot be empty");
}

@Test
public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authenticationDetailsSource cannot be null");
}

@Test
public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null))
Expand Down Expand Up @@ -198,6 +208,40 @@ public void doFilterWhenTokenRevocationRequestValidThenSuccessResponse() throws
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
}

@Test
public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient,
ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());

MockHttpServletRequest request = createTokenRevocationRequest();

AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource = mock(
AuthenticationDetailsSource.class);
WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails(request);
given(authenticationDetailsSource.buildDetails(any())).willReturn(webAuthenticationDetails);
this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);

OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token",
Instant.now(), Instant.now().plus(Duration.ofHours(1)),
new HashSet<>(Arrays.asList("scope1", "scope2")));
OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthentication = new OAuth2TokenRevocationAuthenticationToken(
accessToken, clientPrincipal);

given(this.authenticationManager.authenticate(any())).willReturn(tokenRevocationAuthentication);

SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(clientPrincipal);
SecurityContextHolder.setContext(securityContext);

MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);

this.filter.doFilter(request, response, filterChain);

verify(authenticationDetailsSource).buildDetails(any());
}

@Test
public void doFilterWhenCustomAuthenticationConverterThenUsed() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
Expand Down