From 47865f0d8f2f986795497d1bf3f1ec684512d4b4 Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Wed, 29 Apr 2020 23:29:20 -0400 Subject: [PATCH 01/10] Authorization Code generation in Authorization Endpoint Filter Functionality Covered ------------------------- - Validation of client id & response type - Generate authorization code - Save Authorization - Redirect to client redirect uri - Send error in http response - Send error as query params in redirect uri Validations covered ------------------- - Client Id should be mandatory - Client Id should be registered - Client Id should be configured with Authorization grant type - Response type is mandaotry - Response type should be code --- .../OAuth2AuthorizationEndpointFilter.java | 154 +++++++++++++++++- 1 file changed, 147 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index c8a143ca5..a7f464620 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -15,33 +15,173 @@ */ package org.springframework.security.oauth2.server.authorization.web; +import java.io.IOException; +import java.util.Optional; +import java.util.stream.Stream; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpStatus; import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.web.RedirectStrategy; +import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; - /** * @author Joe Grandja + * @author Paurav Munshi */ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { private Converter authorizationRequestConverter; private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; private StringKeyGenerator codeGenerator; - + private RedirectStrategy authorizationRedirectStrategy; + + private static final OAuth2Error CLIENT_ID_ABSENT_ERROR = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,"Request does not contain client id parameter",null); + private static final OAuth2Error CLIENT_ID_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED,"Can't validate the client id provided with the request",null); + private static final OAuth2Error RESPONSE_TYPE_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE,"Response type should be present and it should be 'code'",null); + private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE,"The provided client does not support Authorization Code grant",null); + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + + RegisteredClient client = null; + OAuth2AuthorizationRequest authorizationRequest = null; + OAuth2Authorization authorization = null; + + try { + client = fetchRegisteredClient(request); + + authorizationRequest = authorizationRequestConverter.convert(request); + validateAuthorizationRequest(authorizationRequest,client); + + String code = codeGenerator.generateKey(); + authorization = buildOAuth2Authorization(client,authorizationRequest,code); + authorizationService.save(authorization); + + this.authorizationRedirectStrategy.sendRedirect(request, response, authorizationRequest.getRedirectUri()); + }catch(OAuth2AuthorizationException authorizationException) { + OAuth2Error authorizationError = authorizationException.getError(); + + if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) + || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED) + || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE)) + sendErrorInResponse(response, authorizationError); + + if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)) + sendErrorInRedirect(request, response, authorizationError, authorizationRequest.getRedirectUri()); + } } + + private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { + String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); + if(StringUtils.isEmpty(clientId)) + throw new OAuth2AuthorizationException(CLIENT_ID_ABSENT_ERROR); + + RegisteredClient client = registeredClientRepository.findByClientId(clientId); + if(client==null) + throw new OAuth2AuthorizationException(CLIENT_ID_NOT_FOUND_ERROR); + + boolean isAuthoirzationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes()) + .anyMatch(grantType -> grantType.equals(AuthorizationGrantType.AUTHORIZATION_CODE)); + if(!isAuthoirzationGrantAllowed) + throw new OAuth2AuthorizationException(AUTHZ_CODE_NOT_SUPPORTED_ERROR); + + return client; + + } + + private OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, + OAuth2AuthorizationRequest authorizationRequest, String code) { + OAuth2Authorization authorization = OAuth2Authorization.createBuilder() + .clientId(authorizationRequest.getClientId()) + .addAttribute(OAuth2ParameterNames.CODE, code) + .build(); + + return authorization; + } + + + private void validateAuthorizationRequest(OAuth2AuthorizationRequest authzRequest, RegisteredClient client) { + OAuth2AuthorizationResponseType responseType = Optional.ofNullable(authzRequest.getResponseType()) + .orElseThrow(() -> new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR)); + + if(!responseType.equals(OAuth2AuthorizationResponseType.CODE)) + throw new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR); + + } + + private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException { + response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), authorizationError.getErrorCode()+":"+authorizationError.getDescription()); + } + + private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, OAuth2Error authorizationError, String redirectUri) throws IOException { + String finalRedirectURI = new StringBuilder(redirectUri) + .append("?").append("error_code=").append(authorizationError.getErrorCode()) + .append("&").append("error_description=").append(authorizationError.getDescription()) + .toString(); + + this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI); + } + public Converter getAuthorizationRequestConverter() { + return authorizationRequestConverter; + } + + public void setAuthorizationRequestConverter( + Converter authorizationRequestConverter) { + this.authorizationRequestConverter = authorizationRequestConverter; + } + + public RegisteredClientRepository getRegisteredClientRepository() { + return registeredClientRepository; + } + + public void setRegisteredClientRepository(RegisteredClientRepository registeredClientRepository) { + this.registeredClientRepository = registeredClientRepository; + } + + public OAuth2AuthorizationService getAuthorizationService() { + return authorizationService; + } + + public void setAuthorizationService(OAuth2AuthorizationService authorizationService) { + this.authorizationService = authorizationService; + } + + public StringKeyGenerator getCodeGenerator() { + return codeGenerator; + } + + public void setCodeGenerator(StringKeyGenerator codeGenerator) { + this.codeGenerator = codeGenerator; + } + + public RedirectStrategy getAuthorizationRedirectStrategy() { + return authorizationRedirectStrategy; + } + + public void getAuthorizationRedirectStrategy(RedirectStrategy redirectStrategy) { + this.authorizationRedirectStrategy = redirectStrategy; + } + } From 8d5685190d4443a024cc65908fb9502107ac21ef Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Thu, 30 Apr 2020 00:31:52 -0400 Subject: [PATCH 02/10] Enhancements in Authorization Code Endpoint Filter - On successfull redirect code was not being sent in the Redirect URL - validations methods were private, so made them protected --- .../web/OAuth2AuthorizationEndpointFilter.java | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index a7f464620..e6a3e63b1 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -77,7 +77,7 @@ protected void doFilterInternal(HttpServletRequest request, authorization = buildOAuth2Authorization(client,authorizationRequest,code); authorizationService.save(authorization); - this.authorizationRedirectStrategy.sendRedirect(request, response, authorizationRequest.getRedirectUri()); + sendCodeOnSuccess(request, response, authorizationRequest, code); }catch(OAuth2AuthorizationException authorizationException) { OAuth2Error authorizationError = authorizationException.getError(); @@ -92,7 +92,7 @@ protected void doFilterInternal(HttpServletRequest request, } - private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { + protected RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); if(StringUtils.isEmpty(clientId)) throw new OAuth2AuthorizationException(CLIENT_ID_ABSENT_ERROR); @@ -110,7 +110,7 @@ private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throw } - private OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, + protected OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, OAuth2AuthorizationRequest authorizationRequest, String code) { OAuth2Authorization authorization = OAuth2Authorization.createBuilder() .clientId(authorizationRequest.getClientId()) @@ -121,7 +121,7 @@ private OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, } - private void validateAuthorizationRequest(OAuth2AuthorizationRequest authzRequest, RegisteredClient client) { + protected void validateAuthorizationRequest(OAuth2AuthorizationRequest authzRequest, RegisteredClient client) { OAuth2AuthorizationResponseType responseType = Optional.ofNullable(authzRequest.getResponseType()) .orElseThrow(() -> new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR)); @@ -130,6 +130,15 @@ private void validateAuthorizationRequest(OAuth2AuthorizationRequest authzReques } + private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response, + OAuth2AuthorizationRequest authorizationRequest, String code) throws IOException { + String redirectUri = new StringBuilder(authorizationRequest.getRedirectUri()) + .append("?").append("code=").append(code) + .toString(); + + this.authorizationRedirectStrategy.sendRedirect(request, response, redirectUri); + } + private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException { response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), authorizationError.getErrorCode()+":"+authorizationError.getDescription()); } From 8f0644e9123a3af9d280771b084c22cc69e9d69b Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Fri, 1 May 2020 02:30:28 -0400 Subject: [PATCH 03/10] Enhancements & Fixes - Adding key constants for issues time and code used flag in OAuth2Authorization - Changing error code when Authorization grant is not supported by client - Change in error method for Unauthorized client - Corrected the validation for AuthorizationGrantType check - Adding attributes Issue time and Code use flag in attributes map - Adding state to redirect uri for both success and failure - Changing method name from get to set as it was wrongly used - Adding a basic test class which will be enhanced more --- .../OAuth2AuthorizationEndpointFilter.java | 47 ++++++++---- ...OAuth2AuthorizationEndpointFilterTest.java | 71 +++++++++++++++++++ 2 files changed, 103 insertions(+), 15 deletions(-) create mode 100644 core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index e6a3e63b1..412624813 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -16,7 +16,11 @@ package org.springframework.security.oauth2.server.authorization.web; import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Stream; import javax.servlet.FilterChain; @@ -56,7 +60,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { private static final OAuth2Error CLIENT_ID_ABSENT_ERROR = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,"Request does not contain client id parameter",null); private static final OAuth2Error CLIENT_ID_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED,"Can't validate the client id provided with the request",null); private static final OAuth2Error RESPONSE_TYPE_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE,"Response type should be present and it should be 'code'",null); - private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE,"The provided client does not support Authorization Code grant",null); + private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT,"The provided client is not authorized to request authorization code",null); @Override protected void doFilterInternal(HttpServletRequest request, @@ -82,12 +86,12 @@ protected void doFilterInternal(HttpServletRequest request, OAuth2Error authorizationError = authorizationException.getError(); if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) - || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED) - || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE)) + || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) sendErrorInResponse(response, authorizationError); - if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)) - sendErrorInRedirect(request, response, authorizationError, authorizationRequest.getRedirectUri()); + if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE) + || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) + sendErrorInRedirect(request, response, authorizationRequest, authorizationError, authorizationRequest.getRedirectUri()); } } @@ -102,7 +106,7 @@ protected RegisteredClient fetchRegisteredClient(HttpServletRequest request) thr throw new OAuth2AuthorizationException(CLIENT_ID_NOT_FOUND_ERROR); boolean isAuthoirzationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes()) - .anyMatch(grantType -> grantType.equals(AuthorizationGrantType.AUTHORIZATION_CODE)); + .anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE)); if(!isAuthoirzationGrantAllowed) throw new OAuth2AuthorizationException(AUTHZ_CODE_NOT_SUPPORTED_ERROR); @@ -115,6 +119,10 @@ protected OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, OAuth2Authorization authorization = OAuth2Authorization.createBuilder() .clientId(authorizationRequest.getClientId()) .addAttribute(OAuth2ParameterNames.CODE, code) + .addAttribute(OAuth2Authorization.ISSUED_AT, Instant.now()) + .addAttribute(OAuth2Authorization.CODE_USED, new AtomicBoolean(false)) + .addAttribute(OAuth2ParameterNames.SCOPE, Optional.ofNullable(authorizationRequest.getScopes()) + .filter(scopes -> !scopes.isEmpty()).orElse(client.getScopes())) .build(); return authorization; @@ -132,10 +140,13 @@ protected void validateAuthorizationRequest(OAuth2AuthorizationRequest authzRequ private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response, OAuth2AuthorizationRequest authorizationRequest, String code) throws IOException { - String redirectUri = new StringBuilder(authorizationRequest.getRedirectUri()) - .append("?").append("code=").append(code) - .toString(); + StringBuilder urlBuilder = new StringBuilder(authorizationRequest.getRedirectUri()) + .append(authorizationRequest.getRedirectUri().contains("?") ? "&" : "?") + .append(OAuth2ParameterNames.CODE).append("=").append(code); + if(!StringUtils.isEmpty(authorizationRequest.getState())) + urlBuilder.append("?").append("state=").append(authorizationRequest.getState()); + String redirectUri = urlBuilder.toString(); this.authorizationRedirectStrategy.sendRedirect(request, response, redirectUri); } @@ -143,12 +154,18 @@ private void sendErrorInResponse(HttpServletResponse response, OAuth2Error autho response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), authorizationError.getErrorCode()+":"+authorizationError.getDescription()); } - private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, OAuth2Error authorizationError, String redirectUri) throws IOException { - String finalRedirectURI = new StringBuilder(redirectUri) - .append("?").append("error_code=").append(authorizationError.getErrorCode()) - .append("&").append("error_description=").append(authorizationError.getDescription()) - .toString(); + private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, + OAuth2AuthorizationRequest authorizationRequest,OAuth2Error authorizationError, + String redirectUri) throws IOException { + StringBuilder urlBuilder = new StringBuilder(redirectUri) + .append(redirectUri.contains("?") ? "&" : "?") + .append("error_code=").append(authorizationError.getErrorCode()) + .append("&").append("error_description=").append(authorizationError.getDescription()); + if(!StringUtils.isEmpty(authorizationRequest.getState())) + urlBuilder.append("?").append("state=").append(authorizationRequest.getState()); + + String finalRedirectURI = urlBuilder.toString(); this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI); } @@ -189,7 +206,7 @@ public RedirectStrategy getAuthorizationRedirectStrategy() { return authorizationRedirectStrategy; } - public void getAuthorizationRedirectStrategy(RedirectStrategy redirectStrategy) { + public void setAuthorizationRedirectStrategy(RedirectStrategy redirectStrategy) { this.authorizationRedirectStrategy = redirectStrategy; } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java new file mode 100644 index 000000000..6858349a0 --- /dev/null +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -0,0 +1,71 @@ +package org.springframework.security.oauth2.server.authorization.web; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.servlet.http.HttpServletRequest; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; +import org.springframework.security.web.RedirectStrategy; + + +/** + * Tests for {@link OAuth2AuthorizationEndpointFilter}. + * + * @author Paurav Munshi + */ + +public class OAuth2AuthorizationEndpointFilterTest { + + private OAuth2AuthorizationEndpointFilter filter; + + private RedirectStrategy authorizationRedirectStrategy = mock(RedirectStrategy.class); + private Converter authorizationConverter = mock(Converter.class); + private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class); + private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class); + private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class); + + @Before + public void setUp() { + filter = new OAuth2AuthorizationEndpointFilter(); + + filter.setAuthorizationRequestConverter(authorizationConverter); + filter.setAuthorizationService(authorizationService); + filter.setCodeGenerator(codeGenerator); + filter.setRegisteredClientRepository(registeredClientRepository); + filter.setAuthorizationRedirectStrategy(authorizationRedirectStrategy); + } + + @Test + public void testSettersAreSettingProperValue() { + OAuth2AuthorizationEndpointFilter blankFilter = new OAuth2AuthorizationEndpointFilter(); + + assertThat(blankFilter.getAuthorizationRedirectStrategy()).isNull(); + assertThat(blankFilter.getAuthorizationRequestConverter()).isNull(); + assertThat(blankFilter.getAuthorizationService()).isNull(); + assertThat(blankFilter.getCodeGenerator()).isNull(); + assertThat(blankFilter.getRegisteredClientRepository()).isNull(); + + blankFilter.setAuthorizationRequestConverter(authorizationConverter); + blankFilter.setAuthorizationService(authorizationService); + blankFilter.setCodeGenerator(codeGenerator); + blankFilter.setRegisteredClientRepository(registeredClientRepository); + blankFilter.setAuthorizationRedirectStrategy(authorizationRedirectStrategy); + + assertThat(blankFilter.getAuthorizationRedirectStrategy()).isEqualTo(authorizationRedirectStrategy); + assertThat(blankFilter.getAuthorizationRequestConverter()).isEqualTo(authorizationConverter); + assertThat(blankFilter.getAuthorizationService()).isEqualTo(authorizationService); + assertThat(blankFilter.getCodeGenerator()).isEqualTo(codeGenerator); + assertThat(blankFilter.getRegisteredClientRepository()).isEqualTo(registeredClientRepository); + } + +} From 46a1c35b456ccadafb4dbc8b2fcca5438d31e6f8 Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Mon, 4 May 2020 17:03:20 -0400 Subject: [PATCH 04/10] Enhancements & Fixes - Added a UUID based Authorization Code generator - Added an Ant Path request matcher to check endpoint validity - Added a new OAuth2 authorization request converter - Use DefaultRedirectStrategy from Spring - Added a check for user authentication present and authenticated - Added check for respone type parameter validity in request - Added check for redirect uri validity in request - Replaced StringBuilder with UriComponentsBuilder - Added JUnit tests for validation in section 4.1.1 of the RFC specs - Added a class which will be responsible for holding sever level messages as constants --- .../util/AuthorizationCodeKeyGenerator.java | 19 ++ .../OAuth2AuthorizationServerMessages.java | 16 + .../OAuth2AuthorizationEndpointFilter.java | 118 +++++--- .../OAuth2AuthorizationRequestConverter.java | 40 +++ .../client/TestRegisteredClients.java | 36 +++ ...OAuth2AuthorizationEndpointFilterTest.java | 284 +++++++++++++++++- 6 files changed, 472 insertions(+), 41 deletions(-) create mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java create mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java create mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java new file mode 100644 index 000000000..f274d89f2 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java @@ -0,0 +1,19 @@ +package org.springframework.security.oauth2.server.authorization.util; + +import java.util.UUID; + +import org.springframework.security.crypto.keygen.StringKeyGenerator; + +/** + * @author Paurav Munshi + * @since 0.0.1 + */ +public class AuthorizationCodeKeyGenerator implements StringKeyGenerator { + + @Override + public String generateKey() { + // TODO Auto-generated method stub + return UUID.randomUUID().toString(); + } + +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java new file mode 100644 index 000000000..448ff68b8 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java @@ -0,0 +1,16 @@ +package org.springframework.security.oauth2.server.authorization.util; + +/** + * @author Paurav Munshi + * @since 0.0.1 + */ +public final class OAuth2AuthorizationServerMessages { + + public static final String REQUEST_MISSING_CLIENT_ID = "Request does not contain client id parameter"; + public static final String CLIENT_ID_UNAUTHORIZED_FOR_CODE = "The provided client is not authorized to request authorization code"; + public static final String RESPONSE_TYPE_MISSING_OR_INVALID = "Response type should be present and it should be 'code'"; + public static final String CLIENT_ID_NOT_FOUND = "Can't validate the client id provided with the request"; + public static final String USER_NOT_AUTHENTICATED = "User must be authenticated to perform this action"; + public static final String REDIRECT_URI_MANDATORY_FOR_CLIENT = "Client is configured with multiple URIs. So a specific redirect uri must be supplied with request"; + +} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 412624813..7a5c4f978 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -16,8 +16,6 @@ package org.springframework.security.oauth2.server.authorization.web; import java.io.IOException; -import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; @@ -30,6 +28,8 @@ import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpStatus; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; @@ -42,25 +42,47 @@ import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.util.AuthorizationCodeKeyGenerator; +import org.springframework.security.oauth2.server.authorization.util.OAuth2AuthorizationServerMessages; +import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.util.UriComponentsBuilder; /** * @author Joe Grandja * @author Paurav Munshi + * @since 0.0.1 */ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { + private Converter authorizationRequestConverter; private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; private StringKeyGenerator codeGenerator; private RedirectStrategy authorizationRedirectStrategy; + private RequestMatcher authorizationEndpiontMatcher; + + private static final String DEFAULT_ENDPOINT = "/oauth2/authorize"; + + private static final OAuth2Error CLIENT_ID_ABSENT_ERROR = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REQUEST_MISSING_CLIENT_ID,null); + private static final OAuth2Error REDIRECT_URI_REQUIRED = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT,null); + private static final OAuth2Error CLIENT_ID_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_NOT_FOUND,null); + private static final OAuth2Error USER_NOT_AUTHENTICATED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.USER_NOT_AUTHENTICATED,null); + private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_UNAUTHORIZED_FOR_CODE,null); + private static final OAuth2Error RESPONSE_TYPE_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID,null); - private static final OAuth2Error CLIENT_ID_ABSENT_ERROR = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,"Request does not contain client id parameter",null); - private static final OAuth2Error CLIENT_ID_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED,"Can't validate the client id provided with the request",null); - private static final OAuth2Error RESPONSE_TYPE_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE,"Response type should be present and it should be 'code'",null); - private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT,"The provided client is not authorized to request authorization code",null); + + + public OAuth2AuthorizationEndpointFilter() { + authorizationEndpiontMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT); + authorizationRequestConverter = new OAuth2AuthorizationRequestConverter(); + codeGenerator = new AuthorizationCodeKeyGenerator(); + authorizationRedirectStrategy = new DefaultRedirectStrategy(); + } @Override protected void doFilterInternal(HttpServletRequest request, @@ -72,30 +94,42 @@ protected void doFilterInternal(HttpServletRequest request, OAuth2Authorization authorization = null; try { + checkUserAuthenticated(); client = fetchRegisteredClient(request); authorizationRequest = authorizationRequestConverter.convert(request); - validateAuthorizationRequest(authorizationRequest,client); + validateAuthorizationRequest(request, client); String code = codeGenerator.generateKey(); authorization = buildOAuth2Authorization(client,authorizationRequest,code); authorizationService.save(authorization); - sendCodeOnSuccess(request, response, authorizationRequest, code); + String redirectUri = getRedirectUri(authorizationRequest,client); + sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code); }catch(OAuth2AuthorizationException authorizationException) { OAuth2Error authorizationError = authorizationException.getError(); if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) - || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) + || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) { sendErrorInResponse(response, authorizationError); - - if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE) - || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) - sendErrorInRedirect(request, response, authorizationRequest, authorizationError, authorizationRequest.getRedirectUri()); + } + else if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE) + || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) { + String redirectUri = getRedirectUri(authorizationRequest,client); + sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri); + }else { + throw new ServletException(authorizationException); + } } } + protected void checkUserAuthenticated() { + Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication(); + if(currentAuth==null || !currentAuth.isAuthenticated()) + throw new OAuth2AuthorizationException(USER_NOT_AUTHENTICATED_ERROR); + } + protected RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); if(StringUtils.isEmpty(clientId)) @@ -129,50 +163,66 @@ protected OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, } - protected void validateAuthorizationRequest(OAuth2AuthorizationRequest authzRequest, RegisteredClient client) { - OAuth2AuthorizationResponseType responseType = Optional.ofNullable(authzRequest.getResponseType()) - .orElseThrow(() -> new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR)); + protected void validateAuthorizationRequest(HttpServletRequest request, RegisteredClient client) { + String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); + if(StringUtils.isEmpty(responseType) + || !responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) + throw new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR); - if(!responseType.equals(OAuth2AuthorizationResponseType.CODE)) - throw new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR); - + String redirectUri = request.getParameter(OAuth2ParameterNames.REDIRECT_URI); + if(StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) + throw new OAuth2AuthorizationException(REDIRECT_URI_REQUIRED); + } + + private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { + return !StringUtils.isEmpty(authorizationRequest.getRedirectUri()) + ? authorizationRequest.getRedirectUri() + : client.getRedirectUris().stream().findFirst().get(); } private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response, - OAuth2AuthorizationRequest authorizationRequest, String code) throws IOException { - StringBuilder urlBuilder = new StringBuilder(authorizationRequest.getRedirectUri()) - .append(authorizationRequest.getRedirectUri().contains("?") ? "&" : "?") - .append(OAuth2ParameterNames.CODE).append("=").append(code); + OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException { + UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) + .queryParam(OAuth2ParameterNames.CODE, code); if(!StringUtils.isEmpty(authorizationRequest.getState())) - urlBuilder.append("?").append("state=").append(authorizationRequest.getState()); + redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - String redirectUri = urlBuilder.toString(); - this.authorizationRedirectStrategy.sendRedirect(request, response, redirectUri); + String finalRedirectUri = redirectUriBuilder.toUriString(); + this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri); } private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException { - response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), authorizationError.getErrorCode()+":"+authorizationError.getDescription()); + int errorStatus = -1; + String errorCode = authorizationError.getErrorCode(); + if(errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) + errorStatus=HttpStatus.FORBIDDEN.value(); + else errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value(); + response.sendError(errorStatus, authorizationError.getErrorCode()+":"+authorizationError.getDescription()); } private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, OAuth2AuthorizationRequest authorizationRequest,OAuth2Error authorizationError, String redirectUri) throws IOException { - StringBuilder urlBuilder = new StringBuilder(redirectUri) - .append(redirectUri.contains("?") ? "&" : "?") - .append("error_code=").append(authorizationError.getErrorCode()) - .append("&").append("error_description=").append(authorizationError.getDescription()); + UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) + .queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode()) + .queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, authorizationError.getDescription()); if(!StringUtils.isEmpty(authorizationRequest.getState())) - urlBuilder.append("?").append("state=").append(authorizationRequest.getState()); + redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - String finalRedirectURI = urlBuilder.toString(); + String finalRedirectURI = redirectUriBuilder.toUriString(); this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI); } - + public Converter getAuthorizationRequestConverter() { return authorizationRequestConverter; } + @Override + protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { + return !authorizationEndpiontMatcher.matches(request); + } + public void setAuthorizationRequestConverter( Converter authorizationRequestConverter) { this.authorizationRequestConverter = authorizationRequestConverter; diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java new file mode 100644 index 000000000..859330d80 --- /dev/null +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java @@ -0,0 +1,40 @@ +package org.springframework.security.oauth2.server.authorization.web; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.StringUtils; + +/** + * @author Paurav Munshi + * @since 0.0.1 + * @see Converter + */ +public class OAuth2AuthorizationRequestConverter implements Converter{ + + @Override + public OAuth2AuthorizationRequest convert(HttpServletRequest request) { + String scope = request.getParameter(OAuth2ParameterNames.SCOPE); + Set scopes = !StringUtils.isEmpty(scope) + ? new LinkedHashSet(Arrays.asList(scope.split(" "))) + : Collections.emptySet(); + + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID)) + .redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)) + .scopes(scopes) + .state(request.getParameter(OAuth2ParameterNames.STATE)) + .authorizationUri(request.getServletPath()) + .build(); + + return authorizationRequest; + } + +} diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java index 5aa7bc0f9..7fc52a349 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java @@ -46,4 +46,40 @@ public static RegisteredClient.Builder registeredClient2() { .scope("profile") .scope("email"); } + + public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() { + return RegisteredClient.withId("valid_client_id") + .clientId("valid_client") + .clientSecret("valid_secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("http://localhost:8080/test-application/callback") + .scope("openid") + .scope("profile") + .scope("email"); + } + + public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() { + return RegisteredClient.withId("valid_client_multi_uri_id") + .clientId("valid_client_multi_uri") + .clientSecret("valid_secret") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("http://localhost:8080/test-application/callback") + .redirectUri("http://localhost:8080/another-test-application/callback") + .scope("openid") + .scope("profile") + .scope("email"); + } + + public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() { + return RegisteredClient.withId("valid_cc_client_id") + .clientId("valid_cc_client") + .clientSecret("valid_secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .scope("openid") + .scope("profile") + .scope("email"); + } } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java index 6858349a0..5ff716b09 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -1,20 +1,37 @@ package org.springframework.security.oauth2.server.authorization.web; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; -import java.util.concurrent.atomic.AtomicBoolean; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import org.junit.Before; import org.junit.Test; import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; -import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.util.OAuth2AuthorizationServerMessages; import org.springframework.security.web.RedirectStrategy; @@ -22,10 +39,15 @@ * Tests for {@link OAuth2AuthorizationEndpointFilter}. * * @author Paurav Munshi + * @since 0.0.1 */ public class OAuth2AuthorizationEndpointFilterTest { + private static final String VALID_CLIENT = "valid_client"; + private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri"; + private static final String VALID_CC_CLIENT = "valid_cc_client"; + private OAuth2AuthorizationEndpointFilter filter; private RedirectStrategy authorizationRedirectStrategy = mock(RedirectStrategy.class); @@ -33,26 +55,258 @@ public class OAuth2AuthorizationEndpointFilterTest { private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class); private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class); private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class); + private Authentication authentication = mock(Authentication.class); @Before public void setUp() { filter = new OAuth2AuthorizationEndpointFilter(); - filter.setAuthorizationRequestConverter(authorizationConverter); filter.setAuthorizationService(authorizationService); filter.setCodeGenerator(codeGenerator); filter.setRegisteredClientRepository(registeredClientRepository); - filter.setAuthorizationRedirectStrategy(authorizationRedirectStrategy); + + SecurityContextHolder.getContext().setAuthentication(authentication); + } + + @Test + public void testFilterRedirectsWithCodeOnValidReq() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(codeGenerator.generateKey()).thenReturn("sample_code"); + when(authentication.isAuthenticated()).thenReturn(true); + + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication).isAuthenticated(); + verify(registeredClientRepository).findByClientId(VALID_CLIENT); + verify(authorizationService).save(any(OAuth2Authorization.class)); + verify(codeGenerator).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); + + } + + @Test + public void testFilterRedirectsWithCodeToDefaultRedirectURIWhenNotPresentInRequest() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(codeGenerator.generateKey()).thenReturn("sample_code"); + when(authentication.isAuthenticated()).thenReturn(true); + + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication).isAuthenticated(); + verify(registeredClientRepository).findByClientId(VALID_CLIENT); + verify(authorizationService).save(any(OAuth2Authorization.class)); + verify(codeGenerator).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); + + } + + @Test + public void testErrorWhenRedirectURINotPresentAndClientHasMulitipleUris() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build(); + when(registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient); + when(authentication.isAuthenticated()).thenReturn(true); + + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication, times(1)).isAuthenticated(); + verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI); + verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(codeGenerator, times(0)).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT); + + } + + @Test + public void testErrorClientIdNotSupportAuthorizationGrantFlow() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build(); + when(registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient); + when(authentication.isAuthenticated()).thenReturn(true); + + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication, times(1)).isAuthenticated(); + verify(registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT); + verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(codeGenerator, times(0)).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.CLIENT_ID_UNAUTHORIZED_FOR_CODE); + + } + + @Test + public void testErrorWhenClientIdMissinInRequest() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + when(authentication.isAuthenticated()).thenReturn(true); + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication).isAuthenticated(); + verify(registeredClientRepository, times(0)).findByClientId(anyString()); + verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(codeGenerator, times(0)).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); + assertThat(response.getContentAsString()).isEmpty(); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REQUEST_MISSING_CLIENT_ID); + + } + + @Test + public void testErrorWhenUnregisteredClientInRequest() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null); + when(codeGenerator.generateKey()).thenReturn("sample_code"); + when(authentication.isAuthenticated()).thenReturn(true); + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication).isAuthenticated(); + verify(registeredClientRepository, times(1)).findByClientId("unregistered_client"); + verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(codeGenerator, times(0)).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); + assertThat(response.getContentAsString()).isEmpty(); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.CLIENT_ID_NOT_FOUND); + + } + + @Test + public void testErrorWhenUnauthenticatedUserInRequest() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + when(authentication.isAuthenticated()).thenReturn(false); + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication).isAuthenticated(); + verify(registeredClientRepository, times(0)).findByClientId(anyString()); + verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(codeGenerator, times(0)).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); + assertThat(response.getContentAsString()).isEmpty(); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.USER_NOT_AUTHENTICATED); + + } + + @Test + public void testShouldNotFilterForUnsupportedEndpoint() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setServletPath("/custom/authorize"); + + boolean willFilterGetInvoked = !filter.shouldNotFilter(request); + + assertThat(willFilterGetInvoked).isEqualTo(false); + + } + + @Test + public void testErrorWhenResponseTypeNotPresent() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(codeGenerator.generateKey()).thenReturn("sample_code"); + when(authentication.isAuthenticated()).thenReturn(true); + + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication).isAuthenticated(); + verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); + verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(codeGenerator, times(0)).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).startsWith(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); + assertThat(response.getRedirectedUrl()).contains("error="+OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE); + assertThat(URLDecoder.decode(response.getRedirectedUrl(), StandardCharsets.UTF_8.toString())).contains("error_description="+OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID); + + } + + @Test + public void testErrorWhenResponseTypeIsUnsupported() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(codeGenerator.generateKey()).thenReturn("sample_code"); + when(authentication.isAuthenticated()).thenReturn(true); + + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication).isAuthenticated(); + verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); + verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(codeGenerator, times(0)).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).startsWith(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); + assertThat(response.getRedirectedUrl()).contains("error="+OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE); + assertThat(URLDecoder.decode(response.getRedirectedUrl(), StandardCharsets.UTF_8.toString())).contains("error_description="+OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID); + } @Test public void testSettersAreSettingProperValue() { OAuth2AuthorizationEndpointFilter blankFilter = new OAuth2AuthorizationEndpointFilter(); - assertThat(blankFilter.getAuthorizationRedirectStrategy()).isNull(); - assertThat(blankFilter.getAuthorizationRequestConverter()).isNull(); + assertThat(blankFilter.getAuthorizationRedirectStrategy()).isNotEqualTo(authorizationRedirectStrategy); + assertThat(blankFilter.getAuthorizationRequestConverter()).isNotEqualTo(authorizationConverter); assertThat(blankFilter.getAuthorizationService()).isNull(); - assertThat(blankFilter.getCodeGenerator()).isNull(); + assertThat(blankFilter.getCodeGenerator()).isNotEqualTo(codeGenerator); assertThat(blankFilter.getRegisteredClientRepository()).isNull(); blankFilter.setAuthorizationRequestConverter(authorizationConverter); @@ -67,5 +321,21 @@ public void testSettersAreSettingProperValue() { assertThat(blankFilter.getCodeGenerator()).isEqualTo(codeGenerator); assertThat(blankFilter.getRegisteredClientRepository()).isEqualTo(registeredClientRepository); } + + + private MockHttpServletRequest getValidMockHttpServletRequest() { + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT); + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code"); + request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email"); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback"); + request.setParameter(OAuth2ParameterNames.STATE, "teststate"); + request.setServletPath("/oauth2/authorize"); + + return request; + + + } } From 29defd4c3c26d7bad5434fbbff3528d8512738f1 Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Mon, 4 May 2020 17:46:44 -0400 Subject: [PATCH 05/10] Resolved checkstyle issues - Added space after commans where ever applicable - Removed trailing whilte spaces - Added white space after if condition - Added headers before package declaration --- .../util/AuthorizationCodeKeyGenerator.java | 15 ++ .../OAuth2AuthorizationServerMessages.java | 15 ++ .../OAuth2AuthorizationEndpointFilter.java | 120 ++++++------ .../OAuth2AuthorizationRequestConverter.java | 19 +- .../client/TestRegisteredClients.java | 4 +- ...OAuth2AuthorizationEndpointFilterTest.java | 171 ++++++++++-------- 6 files changed, 202 insertions(+), 142 deletions(-) diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java index f274d89f2..30ff2be46 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java @@ -1,3 +1,18 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.security.oauth2.server.authorization.util; import java.util.UUID; diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java index 448ff68b8..1ad458470 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java @@ -1,3 +1,18 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.security.oauth2.server.authorization.util; /** diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 7a5c4f978..aaff392c0 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -58,64 +58,64 @@ * @since 0.0.1 */ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { - + private Converter authorizationRequestConverter; private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; private StringKeyGenerator codeGenerator; private RedirectStrategy authorizationRedirectStrategy; private RequestMatcher authorizationEndpiontMatcher; - - private static final String DEFAULT_ENDPOINT = "/oauth2/authorize"; - - private static final OAuth2Error CLIENT_ID_ABSENT_ERROR = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REQUEST_MISSING_CLIENT_ID,null); - private static final OAuth2Error REDIRECT_URI_REQUIRED = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT,null); - private static final OAuth2Error CLIENT_ID_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_NOT_FOUND,null); - private static final OAuth2Error USER_NOT_AUTHENTICATED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.USER_NOT_AUTHENTICATED,null); - private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_UNAUTHORIZED_FOR_CODE,null); - private static final OAuth2Error RESPONSE_TYPE_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID,null); - - - + + private static final String DEFAULT_ENDPOINT = "/oauth2/authorize"; + + private static final OAuth2Error CLIENT_ID_ABSENT_ERROR = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REQUEST_MISSING_CLIENT_ID, null); + private static final OAuth2Error REDIRECT_URI_REQUIRED = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT, null); + private static final OAuth2Error CLIENT_ID_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_NOT_FOUND, null); + private static final OAuth2Error USER_NOT_AUTHENTICATED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.USER_NOT_AUTHENTICATED, null); + private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_UNAUTHORIZED_FOR_CODE, null); + private static final OAuth2Error RESPONSE_TYPE_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID, null); + + + public OAuth2AuthorizationEndpointFilter() { authorizationEndpiontMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT); authorizationRequestConverter = new OAuth2AuthorizationRequestConverter(); codeGenerator = new AuthorizationCodeKeyGenerator(); authorizationRedirectStrategy = new DefaultRedirectStrategy(); } - + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - + RegisteredClient client = null; OAuth2AuthorizationRequest authorizationRequest = null; OAuth2Authorization authorization = null; - + try { checkUserAuthenticated(); client = fetchRegisteredClient(request); - + authorizationRequest = authorizationRequestConverter.convert(request); validateAuthorizationRequest(request, client); - + String code = codeGenerator.generateKey(); - authorization = buildOAuth2Authorization(client,authorizationRequest,code); + authorization = buildOAuth2Authorization(client, authorizationRequest, code); authorizationService.save(authorization); - - String redirectUri = getRedirectUri(authorizationRequest,client); + + String redirectUri = getRedirectUri(authorizationRequest, client); sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code); }catch(OAuth2AuthorizationException authorizationException) { OAuth2Error authorizationError = authorizationException.getError(); - - if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) + + if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) { sendErrorInResponse(response, authorizationError); } - else if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE) + else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE) || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) { - String redirectUri = getRedirectUri(authorizationRequest,client); + String redirectUri = getRedirectUri(authorizationRequest, client); sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri); }else { throw new ServletException(authorizationException); @@ -123,32 +123,32 @@ else if(authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RE } } - + protected void checkUserAuthenticated() { Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication(); - if(currentAuth==null || !currentAuth.isAuthenticated()) + if (currentAuth==null || !currentAuth.isAuthenticated()) throw new OAuth2AuthorizationException(USER_NOT_AUTHENTICATED_ERROR); } - + protected RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); - if(StringUtils.isEmpty(clientId)) + if (StringUtils.isEmpty(clientId)) throw new OAuth2AuthorizationException(CLIENT_ID_ABSENT_ERROR); - + RegisteredClient client = registeredClientRepository.findByClientId(clientId); - if(client==null) + if (client==null) throw new OAuth2AuthorizationException(CLIENT_ID_NOT_FOUND_ERROR); - + boolean isAuthoirzationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes()) .anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE)); - if(!isAuthoirzationGrantAllowed) + if (!isAuthoirzationGrantAllowed) throw new OAuth2AuthorizationException(AUTHZ_CODE_NOT_SUPPORTED_ERROR); - + return client; - + } - - protected OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, + + protected OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, OAuth2AuthorizationRequest authorizationRequest, String code) { OAuth2Authorization authorization = OAuth2Authorization.createBuilder() .clientId(authorizationRequest.getClientId()) @@ -158,62 +158,62 @@ protected OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, .addAttribute(OAuth2ParameterNames.SCOPE, Optional.ofNullable(authorizationRequest.getScopes()) .filter(scopes -> !scopes.isEmpty()).orElse(client.getScopes())) .build(); - + return authorization; } - - + + protected void validateAuthorizationRequest(HttpServletRequest request, RegisteredClient client) { String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); - if(StringUtils.isEmpty(responseType) + if (StringUtils.isEmpty(responseType) || !responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) - throw new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR); - + throw new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR); + String redirectUri = request.getParameter(OAuth2ParameterNames.REDIRECT_URI); - if(StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) + if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) throw new OAuth2AuthorizationException(REDIRECT_URI_REQUIRED); } - + private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { - return !StringUtils.isEmpty(authorizationRequest.getRedirectUri()) - ? authorizationRequest.getRedirectUri() + return !StringUtils.isEmpty(authorizationRequest.getRedirectUri()) + ? authorizationRequest.getRedirectUri() : client.getRedirectUris().stream().findFirst().get(); } - + private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response, OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException { UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) .queryParam(OAuth2ParameterNames.CODE, code); - if(!StringUtils.isEmpty(authorizationRequest.getState())) + if (!StringUtils.isEmpty(authorizationRequest.getState())) redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - + String finalRedirectUri = redirectUriBuilder.toUriString(); this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri); } - + private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException { int errorStatus = -1; String errorCode = authorizationError.getErrorCode(); - if(errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) + if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) errorStatus=HttpStatus.FORBIDDEN.value(); else errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value(); response.sendError(errorStatus, authorizationError.getErrorCode()+":"+authorizationError.getDescription()); } - + private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, - OAuth2AuthorizationRequest authorizationRequest,OAuth2Error authorizationError, + OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError, String redirectUri) throws IOException { UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) .queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode()) .queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, authorizationError.getDescription()); - - if(!StringUtils.isEmpty(authorizationRequest.getState())) + + if (!StringUtils.isEmpty(authorizationRequest.getState())) redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - + String finalRedirectURI = redirectUriBuilder.toUriString(); this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI); } - + public Converter getAuthorizationRequestConverter() { return authorizationRequestConverter; } @@ -255,9 +255,9 @@ public void setCodeGenerator(StringKeyGenerator codeGenerator) { public RedirectStrategy getAuthorizationRedirectStrategy() { return authorizationRedirectStrategy; } - + public void setAuthorizationRedirectStrategy(RedirectStrategy redirectStrategy) { this.authorizationRedirectStrategy = redirectStrategy; } - + } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java index 859330d80..357bacc8f 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java @@ -1,3 +1,18 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.security.oauth2.server.authorization.web; import java.util.Arrays; @@ -25,7 +40,7 @@ public OAuth2AuthorizationRequest convert(HttpServletRequest request) { Set scopes = !StringUtils.isEmpty(scope) ? new LinkedHashSet(Arrays.asList(scope.split(" "))) : Collections.emptySet(); - + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID)) .redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)) @@ -33,7 +48,7 @@ public OAuth2AuthorizationRequest convert(HttpServletRequest request) { .state(request.getParameter(OAuth2ParameterNames.STATE)) .authorizationUri(request.getServletPath()) .build(); - + return authorizationRequest; } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java index 7fc52a349..502fa4822 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java @@ -58,7 +58,7 @@ public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() .scope("profile") .scope("email"); } - + public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() { return RegisteredClient.withId("valid_client_multi_uri_id") .clientId("valid_client_multi_uri") @@ -71,7 +71,7 @@ public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirec .scope("profile") .scope("email"); } - + public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() { return RegisteredClient.withId("valid_cc_client_id") .clientId("valid_cc_client") diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java index 5ff716b09..62dcfcbe6 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.security.oauth2.server.authorization.web; import static org.assertj.core.api.Assertions.assertThat; @@ -43,80 +58,80 @@ */ public class OAuth2AuthorizationEndpointFilterTest { - + private static final String VALID_CLIENT = "valid_client"; private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri"; private static final String VALID_CC_CLIENT = "valid_cc_client"; private OAuth2AuthorizationEndpointFilter filter; - + private RedirectStrategy authorizationRedirectStrategy = mock(RedirectStrategy.class); private Converter authorizationConverter = mock(Converter.class); private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class); private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class); private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class); private Authentication authentication = mock(Authentication.class); - + @Before public void setUp() { filter = new OAuth2AuthorizationEndpointFilter(); - + filter.setAuthorizationService(authorizationService); filter.setCodeGenerator(codeGenerator); filter.setRegisteredClientRepository(registeredClientRepository); - + SecurityContextHolder.getContext().setAuthentication(authentication); } - + @Test public void testFilterRedirectsWithCodeOnValidReq() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); when(codeGenerator.generateKey()).thenReturn("sample_code"); when(authentication.isAuthenticated()).thenReturn(true); - - + + filter.doFilterInternal(request, response, filterChain); - + verify(authentication).isAuthenticated(); verify(registeredClientRepository).findByClientId(VALID_CLIENT); verify(authorizationService).save(any(OAuth2Authorization.class)); verify(codeGenerator).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); - + } - + @Test public void testFilterRedirectsWithCodeToDefaultRedirectURIWhenNotPresentInRequest() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); when(codeGenerator.generateKey()).thenReturn("sample_code"); when(authentication.isAuthenticated()).thenReturn(true); - - + + filter.doFilterInternal(request, response, filterChain); - + verify(authentication).isAuthenticated(); verify(registeredClientRepository).findByClientId(VALID_CLIENT); verify(authorizationService).save(any(OAuth2Authorization.class)); verify(codeGenerator).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); - + } - + @Test public void testErrorWhenRedirectURINotPresentAndClientHasMulitipleUris() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); @@ -124,207 +139,207 @@ public void testErrorWhenRedirectURINotPresentAndClientHasMulitipleUris() throws request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build(); when(registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient); when(authentication.isAuthenticated()).thenReturn(true); - - + + filter.doFilterInternal(request, response, filterChain); - + verify(authentication, times(1)).isAuthenticated(); verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI); verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); verify(codeGenerator, times(0)).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT); - + } - + @Test public void testErrorClientIdNotSupportAuthorizationGrantFlow() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build(); when(registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient); when(authentication.isAuthenticated()).thenReturn(true); - - + + filter.doFilterInternal(request, response, filterChain); - + verify(authentication, times(1)).isAuthenticated(); verify(registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT); verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); verify(codeGenerator, times(0)).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.CLIENT_ID_UNAUTHORIZED_FOR_CODE); - + } - + @Test public void testErrorWhenClientIdMissinInRequest() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.CLIENT_ID, ""); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + when(authentication.isAuthenticated()).thenReturn(true); - + filter.doFilterInternal(request, response, filterChain); - + verify(authentication).isAuthenticated(); verify(registeredClientRepository, times(0)).findByClientId(anyString()); verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); verify(codeGenerator, times(0)).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); assertThat(response.getContentAsString()).isEmpty(); assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REQUEST_MISSING_CLIENT_ID); - + } - + @Test public void testErrorWhenUnregisteredClientInRequest() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); when(registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null); when(codeGenerator.generateKey()).thenReturn("sample_code"); when(authentication.isAuthenticated()).thenReturn(true); - + filter.doFilterInternal(request, response, filterChain); - + verify(authentication).isAuthenticated(); verify(registeredClientRepository, times(1)).findByClientId("unregistered_client"); verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); verify(codeGenerator, times(0)).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); assertThat(response.getContentAsString()).isEmpty(); assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.CLIENT_ID_NOT_FOUND); - + } - + @Test public void testErrorWhenUnauthenticatedUserInRequest() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + when(authentication.isAuthenticated()).thenReturn(false); - + filter.doFilterInternal(request, response, filterChain); - + verify(authentication).isAuthenticated(); verify(registeredClientRepository, times(0)).findByClientId(anyString()); verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); verify(codeGenerator, times(0)).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); assertThat(response.getContentAsString()).isEmpty(); assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.USER_NOT_AUTHENTICATED); - + } - + @Test public void testShouldNotFilterForUnsupportedEndpoint() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setServletPath("/custom/authorize"); - + boolean willFilterGetInvoked = !filter.shouldNotFilter(request); - + assertThat(willFilterGetInvoked).isEqualTo(false); - + } - + @Test public void testErrorWhenResponseTypeNotPresent() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, ""); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); when(codeGenerator.generateKey()).thenReturn("sample_code"); when(authentication.isAuthenticated()).thenReturn(true); - - + + filter.doFilterInternal(request, response, filterChain); - + verify(authentication).isAuthenticated(); verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); verify(codeGenerator, times(0)).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).startsWith(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); assertThat(response.getRedirectedUrl()).contains("error="+OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE); assertThat(URLDecoder.decode(response.getRedirectedUrl(), StandardCharsets.UTF_8.toString())).contains("error_description="+OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID); - + } - + @Test public void testErrorWhenResponseTypeIsUnsupported() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); when(codeGenerator.generateKey()).thenReturn("sample_code"); when(authentication.isAuthenticated()).thenReturn(true); - - + + filter.doFilterInternal(request, response, filterChain); - + verify(authentication).isAuthenticated(); verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); verify(codeGenerator, times(0)).generateKey(); - + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).startsWith(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); assertThat(response.getRedirectedUrl()).contains("error="+OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE); assertThat(URLDecoder.decode(response.getRedirectedUrl(), StandardCharsets.UTF_8.toString())).contains("error_description="+OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID); - + } @Test public void testSettersAreSettingProperValue() { OAuth2AuthorizationEndpointFilter blankFilter = new OAuth2AuthorizationEndpointFilter(); - + assertThat(blankFilter.getAuthorizationRedirectStrategy()).isNotEqualTo(authorizationRedirectStrategy); assertThat(blankFilter.getAuthorizationRequestConverter()).isNotEqualTo(authorizationConverter); assertThat(blankFilter.getAuthorizationService()).isNull(); assertThat(blankFilter.getCodeGenerator()).isNotEqualTo(codeGenerator); assertThat(blankFilter.getRegisteredClientRepository()).isNull(); - + blankFilter.setAuthorizationRequestConverter(authorizationConverter); blankFilter.setAuthorizationService(authorizationService); blankFilter.setCodeGenerator(codeGenerator); blankFilter.setRegisteredClientRepository(registeredClientRepository); blankFilter.setAuthorizationRedirectStrategy(authorizationRedirectStrategy); - + assertThat(blankFilter.getAuthorizationRedirectStrategy()).isEqualTo(authorizationRedirectStrategy); assertThat(blankFilter.getAuthorizationRequestConverter()).isEqualTo(authorizationConverter); assertThat(blankFilter.getAuthorizationService()).isEqualTo(authorizationService); assertThat(blankFilter.getCodeGenerator()).isEqualTo(codeGenerator); assertThat(blankFilter.getRegisteredClientRepository()).isEqualTo(registeredClientRepository); } - - + + private MockHttpServletRequest getValidMockHttpServletRequest() { - + MockHttpServletRequest request = new MockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT); request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code"); @@ -332,10 +347,10 @@ private MockHttpServletRequest getValidMockHttpServletRequest() { request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback"); request.setParameter(OAuth2ParameterNames.STATE, "teststate"); request.setServletPath("/oauth2/authorize"); - + return request; - - + + } } From 3834f1747fcc753c833fc4a7e410c71a4b56fb1a Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Mon, 4 May 2020 18:51:15 -0400 Subject: [PATCH 06/10] Added check for same redirect uri should be configured in client id --- .../OAuth2AuthorizationServerMessages.java | 1 + .../OAuth2AuthorizationEndpointFilter.java | 3 +++ ...OAuth2AuthorizationEndpointFilterTest.java | 24 +++++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java index 1ad458470..762ec05b2 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java @@ -27,5 +27,6 @@ public final class OAuth2AuthorizationServerMessages { public static final String CLIENT_ID_NOT_FOUND = "Can't validate the client id provided with the request"; public static final String USER_NOT_AUTHENTICATED = "User must be authenticated to perform this action"; public static final String REDIRECT_URI_MANDATORY_FOR_CLIENT = "Client is configured with multiple URIs. So a specific redirect uri must be supplied with request"; + public static final String REQUESTED_REDIRECT_URI_INVALID = "Requested redirect uri is invalid."; } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index aaff392c0..4ecf8be99 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -70,6 +70,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { private static final OAuth2Error CLIENT_ID_ABSENT_ERROR = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REQUEST_MISSING_CLIENT_ID, null); private static final OAuth2Error REDIRECT_URI_REQUIRED = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT, null); + private static final OAuth2Error INVALID_REDIRECT_URI_REQUESTED = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REQUESTED_REDIRECT_URI_INVALID, null); private static final OAuth2Error CLIENT_ID_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_NOT_FOUND, null); private static final OAuth2Error USER_NOT_AUTHENTICATED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.USER_NOT_AUTHENTICATED, null); private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_UNAUTHORIZED_FOR_CODE, null); @@ -172,6 +173,8 @@ protected void validateAuthorizationRequest(HttpServletRequest request, Register String redirectUri = request.getParameter(OAuth2ParameterNames.REDIRECT_URI); if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) throw new OAuth2AuthorizationException(REDIRECT_URI_REQUIRED); + if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) + throw new OAuth2AuthorizationException(INVALID_REDIRECT_URI_REQUESTED); } private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java index 62dcfcbe6..e080fa04c 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -156,6 +156,30 @@ public void testErrorWhenRedirectURINotPresentAndClientHasMulitipleUris() throws assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT); } + + @Test + public void testErrorWhenRequestedRedirectUriNotConfiguredInClient() throws Exception { + MockHttpServletRequest request = getValidMockHttpServletRequest(); + request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); + when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(authentication.isAuthenticated()).thenReturn(true); + + + filter.doFilterInternal(request, response, filterChain); + + verify(authentication, times(1)).isAuthenticated(); + verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); + verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(codeGenerator, times(0)).generateKey(); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REQUESTED_REDIRECT_URI_INVALID); + + } @Test public void testErrorClientIdNotSupportAuthorizationGrantFlow() throws Exception { From 4018cb22378fef2cfd947febd6b7e271203f337d Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Mon, 4 May 2020 18:54:24 -0400 Subject: [PATCH 07/10] Fixed checkstyle issue in redirect uri test in Test class --- .../web/OAuth2AuthorizationEndpointFilterTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java index e080fa04c..87c8cd740 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -156,7 +156,7 @@ public void testErrorWhenRedirectURINotPresentAndClientHasMulitipleUris() throws assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT); } - + @Test public void testErrorWhenRequestedRedirectUriNotConfiguredInClient() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); From 357dfad75fc2c712f7cfeb82de10c89db640bcf8 Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Mon, 11 May 2020 11:08:11 -0400 Subject: [PATCH 08/10] Enhancements & fix the code formatting & style issues - Removed Constants from the filter class - Remove ServerMessages file - Removed AuthorizationCodeKeyGenerator instead used Base64StringKeyGenerator - Removed getter methods - Only provide setters for default and options fields - Added mandatory fields to ctor - Add braces to single statement if conditions - Adopted method declaration seq to be ctor, setter, overridden methods and private methods - Used this keyword where ever access instance members - Added tests for ctor and setters - Renamed tests to methodName-When-Then format --- .../util/AuthorizationCodeKeyGenerator.java | 34 --- .../OAuth2AuthorizationServerMessages.java | 32 -- .../OAuth2AuthorizationEndpointFilter.java | 187 +++++------- .../OAuth2AuthorizationRequestConverter.java | 2 +- ...OAuth2AuthorizationEndpointFilterTest.java | 282 +++++++++--------- 5 files changed, 230 insertions(+), 307 deletions(-) delete mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java delete mode 100644 core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java deleted file mode 100644 index 30ff2be46..000000000 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/AuthorizationCodeKeyGenerator.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.util; - -import java.util.UUID; - -import org.springframework.security.crypto.keygen.StringKeyGenerator; - -/** - * @author Paurav Munshi - * @since 0.0.1 - */ -public class AuthorizationCodeKeyGenerator implements StringKeyGenerator { - - @Override - public String generateKey() { - // TODO Auto-generated method stub - return UUID.randomUUID().toString(); - } - -} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java deleted file mode 100644 index 762ec05b2..000000000 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/util/OAuth2AuthorizationServerMessages.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.util; - -/** - * @author Paurav Munshi - * @since 0.0.1 - */ -public final class OAuth2AuthorizationServerMessages { - - public static final String REQUEST_MISSING_CLIENT_ID = "Request does not contain client id parameter"; - public static final String CLIENT_ID_UNAUTHORIZED_FOR_CODE = "The provided client is not authorized to request authorization code"; - public static final String RESPONSE_TYPE_MISSING_OR_INVALID = "Response type should be present and it should be 'code'"; - public static final String CLIENT_ID_NOT_FOUND = "Can't validate the client id provided with the request"; - public static final String USER_NOT_AUTHENTICATED = "User must be authenticated to perform this action"; - public static final String REDIRECT_URI_MANDATORY_FOR_CLIENT = "Client is configured with multiple URIs. So a specific redirect uri must be supplied with request"; - public static final String REQUESTED_REDIRECT_URI_INVALID = "Requested redirect uri is invalid."; - -} diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 4ecf8be99..313fc7510 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -16,9 +16,6 @@ package org.springframework.security.oauth2.server.authorization.web; import java.io.IOException; -import java.time.Instant; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Stream; import javax.servlet.FilterChain; @@ -30,6 +27,7 @@ import org.springframework.http.HttpStatus; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; @@ -42,12 +40,11 @@ import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; -import org.springframework.security.oauth2.server.authorization.util.AuthorizationCodeKeyGenerator; -import org.springframework.security.oauth2.server.authorization.util.OAuth2AuthorizationServerMessages; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.util.UriComponentsBuilder; @@ -59,30 +56,47 @@ */ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter { - private Converter authorizationRequestConverter; + private static final String DEFAULT_ENDPOINT = "/oauth2/authorize"; + + private Converter authorizationRequestConverter = new OAuth2AuthorizationRequestConverter(); private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; - private StringKeyGenerator codeGenerator; - private RedirectStrategy authorizationRedirectStrategy; - private RequestMatcher authorizationEndpiontMatcher; + private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(); + private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); + private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT); + + public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null."); + Assert.notNull(authorizationService, "authorizationService cannot be null."); + this.registeredClientRepository = registeredClientRepository; + this.authorizationService = authorizationService; + } - private static final String DEFAULT_ENDPOINT = "/oauth2/authorize"; + public final void setAuthorizationRequestConverter( + Converter authorizationRequestConverter) { + Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null"); + this.authorizationRequestConverter = authorizationRequestConverter; + } - private static final OAuth2Error CLIENT_ID_ABSENT_ERROR = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REQUEST_MISSING_CLIENT_ID, null); - private static final OAuth2Error REDIRECT_URI_REQUIRED = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT, null); - private static final OAuth2Error INVALID_REDIRECT_URI_REQUESTED = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2AuthorizationServerMessages.REQUESTED_REDIRECT_URI_INVALID, null); - private static final OAuth2Error CLIENT_ID_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_NOT_FOUND, null); - private static final OAuth2Error USER_NOT_AUTHENTICATED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.USER_NOT_AUTHENTICATED, null); - private static final OAuth2Error AUTHZ_CODE_NOT_SUPPORTED_ERROR = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2AuthorizationServerMessages.CLIENT_ID_UNAUTHORIZED_FOR_CODE, null); - private static final OAuth2Error RESPONSE_TYPE_NOT_FOUND_ERROR = new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID, null); + public final void setCodeGenerator(StringKeyGenerator codeGenerator) { + Assert.notNull(codeGenerator, "codeGenerator cannot be set to null"); + this.codeGenerator = codeGenerator; + } + public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) { + Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null"); + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + } + public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) { + Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null"); + this.authorizationEndpointMatcher = authorizationEndpointMatcher; + } - public OAuth2AuthorizationEndpointFilter() { - authorizationEndpiontMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT); - authorizationRequestConverter = new OAuth2AuthorizationRequestConverter(); - codeGenerator = new AuthorizationCodeKeyGenerator(); - authorizationRedirectStrategy = new DefaultRedirectStrategy(); + @Override + protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { + return !this.authorizationEndpointMatcher.matches(request); } @Override @@ -98,16 +112,17 @@ protected void doFilterInternal(HttpServletRequest request, checkUserAuthenticated(); client = fetchRegisteredClient(request); - authorizationRequest = authorizationRequestConverter.convert(request); + authorizationRequest = this.authorizationRequestConverter.convert(request); validateAuthorizationRequest(request, client); - String code = codeGenerator.generateKey(); + String code = this.codeGenerator.generateKey(); authorization = buildOAuth2Authorization(client, authorizationRequest, code); - authorizationService.save(authorization); + this.authorizationService.save(authorization); String redirectUri = getRedirectUri(authorizationRequest, client); sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code); - }catch(OAuth2AuthorizationException authorizationException) { + } + catch(OAuth2AuthorizationException authorizationException) { OAuth2Error authorizationError = authorizationException.getError(); if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) @@ -118,63 +133,68 @@ else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_R || authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) { String redirectUri = getRedirectUri(authorizationRequest, client); sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri); - }else { + } + else { throw new ServletException(authorizationException); } } } - protected void checkUserAuthenticated() { + private void checkUserAuthenticated() { Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication(); - if (currentAuth==null || !currentAuth.isAuthenticated()) - throw new OAuth2AuthorizationException(USER_NOT_AUTHENTICATED_ERROR); + if (currentAuth==null || !currentAuth.isAuthenticated()) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); + } } - protected RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { + private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException { String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); - if (StringUtils.isEmpty(clientId)) - throw new OAuth2AuthorizationException(CLIENT_ID_ABSENT_ERROR); + if (StringUtils.isEmpty(clientId)) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + } - RegisteredClient client = registeredClientRepository.findByClientId(clientId); - if (client==null) - throw new OAuth2AuthorizationException(CLIENT_ID_NOT_FOUND_ERROR); + RegisteredClient client = this.registeredClientRepository.findByClientId(clientId); + if (client==null) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); + } - boolean isAuthoirzationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes()) + boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes()) .anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE)); - if (!isAuthoirzationGrantAllowed) - throw new OAuth2AuthorizationException(AUTHZ_CODE_NOT_SUPPORTED_ERROR); + if (!isAuthorizationGrantAllowed) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED)); + } return client; } - protected OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, + private OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, OAuth2AuthorizationRequest authorizationRequest, String code) { OAuth2Authorization authorization = OAuth2Authorization.createBuilder() .clientId(authorizationRequest.getClientId()) .addAttribute(OAuth2ParameterNames.CODE, code) - .addAttribute(OAuth2Authorization.ISSUED_AT, Instant.now()) - .addAttribute(OAuth2Authorization.CODE_USED, new AtomicBoolean(false)) - .addAttribute(OAuth2ParameterNames.SCOPE, Optional.ofNullable(authorizationRequest.getScopes()) - .filter(scopes -> !scopes.isEmpty()).orElse(client.getScopes())) + .attribures(authorizationRequest.getAttributes()) .build(); return authorization; } - protected void validateAuthorizationRequest(HttpServletRequest request, RegisteredClient client) { + private void validateAuthorizationRequest(HttpServletRequest request, RegisteredClient client) { String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); if (StringUtils.isEmpty(responseType) - || !responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) - throw new OAuth2AuthorizationException(RESPONSE_TYPE_NOT_FOUND_ERROR); + || !responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)); + } String redirectUri = request.getParameter(OAuth2ParameterNames.REDIRECT_URI); - if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) - throw new OAuth2AuthorizationException(REDIRECT_URI_REQUIRED); - if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) - throw new OAuth2AuthorizationException(INVALID_REDIRECT_URI_REQUESTED); + if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + } + if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) { + throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + } } private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { @@ -187,8 +207,9 @@ private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse r OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException { UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) .queryParam(OAuth2ParameterNames.CODE, code); - if (!StringUtils.isEmpty(authorizationRequest.getState())) + if (!StringUtils.isEmpty(authorizationRequest.getState())) { redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + } String finalRedirectUri = redirectUriBuilder.toUriString(); this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri); @@ -197,70 +218,26 @@ private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse r private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException { int errorStatus = -1; String errorCode = authorizationError.getErrorCode(); - if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) + if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) { errorStatus=HttpStatus.FORBIDDEN.value(); - else errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value(); - response.sendError(errorStatus, authorizationError.getErrorCode()+":"+authorizationError.getDescription()); + } + else { + errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value(); + } + response.sendError(errorStatus, authorizationError.getErrorCode()); } private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response, OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError, String redirectUri) throws IOException { UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri) - .queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode()) - .queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, authorizationError.getDescription()); + .queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode()); - if (!StringUtils.isEmpty(authorizationRequest.getState())) + if (!StringUtils.isEmpty(authorizationRequest.getState())) { redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + } String finalRedirectURI = redirectUriBuilder.toUriString(); this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI); } - - public Converter getAuthorizationRequestConverter() { - return authorizationRequestConverter; - } - - @Override - protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { - return !authorizationEndpiontMatcher.matches(request); - } - - public void setAuthorizationRequestConverter( - Converter authorizationRequestConverter) { - this.authorizationRequestConverter = authorizationRequestConverter; - } - - public RegisteredClientRepository getRegisteredClientRepository() { - return registeredClientRepository; - } - - public void setRegisteredClientRepository(RegisteredClientRepository registeredClientRepository) { - this.registeredClientRepository = registeredClientRepository; - } - - public OAuth2AuthorizationService getAuthorizationService() { - return authorizationService; - } - - public void setAuthorizationService(OAuth2AuthorizationService authorizationService) { - this.authorizationService = authorizationService; - } - - public StringKeyGenerator getCodeGenerator() { - return codeGenerator; - } - - public void setCodeGenerator(StringKeyGenerator codeGenerator) { - this.codeGenerator = codeGenerator; - } - - public RedirectStrategy getAuthorizationRedirectStrategy() { - return authorizationRedirectStrategy; - } - - public void setAuthorizationRedirectStrategy(RedirectStrategy redirectStrategy) { - this.authorizationRedirectStrategy = redirectStrategy; - } - } diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java index 357bacc8f..619a74229 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java @@ -32,7 +32,7 @@ * @since 0.0.1 * @see Converter */ -public class OAuth2AuthorizationRequestConverter implements Converter{ +public class OAuth2AuthorizationRequestConverter implements Converter { @Override public OAuth2AuthorizationRequest convert(HttpServletRequest request) { diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java index 87c8cd740..e826bd010 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -16,22 +16,21 @@ package org.springframework.security.oauth2.server.authorization.web; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.net.URLDecoder; -import java.nio.charset.StandardCharsets; - import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.junit.Before; import org.junit.Test; -import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpStatus; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -39,15 +38,12 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import org.springframework.security.oauth2.server.authorization.util.OAuth2AuthorizationServerMessages; -import org.springframework.security.web.RedirectStrategy; /** @@ -65,8 +61,6 @@ public class OAuth2AuthorizationEndpointFilterTest { private OAuth2AuthorizationEndpointFilter filter; - private RedirectStrategy authorizationRedirectStrategy = mock(RedirectStrategy.class); - private Converter authorizationConverter = mock(Converter.class); private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class); private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class); private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class); @@ -74,33 +68,67 @@ public class OAuth2AuthorizationEndpointFilterTest { @Before public void setUp() { - filter = new OAuth2AuthorizationEndpointFilter(); + this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService); + this.filter.setCodeGenerator(this.codeGenerator); + + SecurityContextHolder.getContext().setAuthentication(this.authentication); + } + + @Test + public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception { + assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception { + assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { + assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { + assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null)) + .isInstanceOf(IllegalArgumentException.class); + } - filter.setAuthorizationService(authorizationService); - filter.setCodeGenerator(codeGenerator); - filter.setRegisteredClientRepository(registeredClientRepository); + @Test + public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { + assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } - SecurityContextHolder.getContext().setAuthentication(authentication); + @Test + public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { + assertThatThrownBy(() ->this.filter.setCodeGenerator(null)) + .isInstanceOf(IllegalArgumentException.class); } @Test - public void testFilterRedirectsWithCodeOnValidReq() throws Exception { + public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(codeGenerator.generateKey()).thenReturn("sample_code"); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.isAuthenticated()).thenReturn(true); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication).isAuthenticated(); - verify(registeredClientRepository).findByClientId(VALID_CLIENT); - verify(authorizationService).save(any(OAuth2Authorization.class)); - verify(codeGenerator).generateKey(); + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository).findByClientId(VALID_CLIENT); + verify(this.authorizationService).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); @@ -108,24 +136,24 @@ public void testFilterRedirectsWithCodeOnValidReq() throws Exception { } @Test - public void testFilterRedirectsWithCodeToDefaultRedirectURIWhenNotPresentInRequest() throws Exception { + public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(codeGenerator.generateKey()).thenReturn("sample_code"); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.isAuthenticated()).thenReturn(true); + this.filter.doFilter(request, response, filterChain); - filter.doFilterInternal(request, response, filterChain); - - verify(authentication).isAuthenticated(); - verify(registeredClientRepository).findByClientId(VALID_CLIENT); - verify(authorizationService).save(any(OAuth2Authorization.class)); - verify(codeGenerator).generateKey(); + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository).findByClientId(VALID_CLIENT); + verify(this.authorizationService).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); @@ -133,7 +161,7 @@ public void testFilterRedirectsWithCodeToDefaultRedirectURIWhenNotPresentInReque } @Test - public void testErrorWhenRedirectURINotPresentAndClientHasMulitipleUris() throws Exception { + public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI); request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); @@ -141,227 +169,211 @@ public void testErrorWhenRedirectURINotPresentAndClientHasMulitipleUris() throws FilterChain filterChain = mock(FilterChain.class); RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build(); - when(registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient); + when(this.authentication.isAuthenticated()).thenReturn(true); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication, times(1)).isAuthenticated(); - verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI); - verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(codeGenerator, times(0)).generateKey(); + verify(this.authentication, times(1)).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REDIRECT_URI_MANDATORY_FOR_CLIENT); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); } @Test - public void testErrorWhenRequestedRedirectUriNotConfiguredInClient() throws Exception { + public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(this.authentication.isAuthenticated()).thenReturn(true); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication, times(1)).isAuthenticated(); - verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); - verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(codeGenerator, times(0)).generateKey(); + verify(this.authentication, times(1)).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REQUESTED_REDIRECT_URI_INVALID); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); } @Test - public void testErrorClientIdNotSupportAuthorizationGrantFlow() throws Exception { + public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build(); - when(registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient); + when(this.authentication.isAuthenticated()).thenReturn(true); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication, times(1)).isAuthenticated(); - verify(registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT); - verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(codeGenerator, times(0)).generateKey(); + verify(this.authentication, times(1)).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.CLIENT_ID_UNAUTHORIZED_FOR_CODE); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); } @Test - public void testErrorWhenClientIdMissinInRequest() throws Exception { + public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.CLIENT_ID, ""); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.authentication.isAuthenticated()).thenReturn(true); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication).isAuthenticated(); - verify(registeredClientRepository, times(0)).findByClientId(anyString()); - verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(codeGenerator, times(0)).generateKey(); + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository, times(0)).findByClientId(anyString()); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); assertThat(response.getContentAsString()).isEmpty(); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST+":"+OAuth2AuthorizationServerMessages.REQUEST_MISSING_CLIENT_ID); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); } @Test - public void testErrorWhenUnregisteredClientInRequest() throws Exception { + public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null); - when(codeGenerator.generateKey()).thenReturn("sample_code"); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null); + when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.isAuthenticated()).thenReturn(true); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication).isAuthenticated(); - verify(registeredClientRepository, times(1)).findByClientId("unregistered_client"); - verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(codeGenerator, times(0)).generateKey(); + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client"); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); assertThat(response.getContentAsString()).isEmpty(); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.CLIENT_ID_NOT_FOUND); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); } @Test - public void testErrorWhenUnauthenticatedUserInRequest() throws Exception { + public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); when(authentication.isAuthenticated()).thenReturn(false); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication).isAuthenticated(); - verify(registeredClientRepository, times(0)).findByClientId(anyString()); - verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(codeGenerator, times(0)).generateKey(); + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository, times(0)).findByClientId(anyString()); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); assertThat(response.getContentAsString()).isEmpty(); - assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED+":"+OAuth2AuthorizationServerMessages.USER_NOT_AUTHENTICATED); + assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); } @Test - public void testShouldNotFilterForUnsupportedEndpoint() throws Exception { + public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setServletPath("/custom/authorize"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); - boolean willFilterGetInvoked = !filter.shouldNotFilter(request); - - assertThat(willFilterGetInvoked).isEqualTo(false); + OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); + spyFilter.doFilter(request, response, filterChain); + verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); + verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); } @Test - public void testErrorWhenResponseTypeNotPresent() throws Exception { + public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, ""); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(codeGenerator.generateKey()).thenReturn("sample_code"); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.isAuthenticated()).thenReturn(true); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication).isAuthenticated(); - verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); - verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(codeGenerator, times(0)).generateKey(); + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).startsWith(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); assertThat(response.getRedirectedUrl()).contains("error="+OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE); - assertThat(URLDecoder.decode(response.getRedirectedUrl(), StandardCharsets.UTF_8.toString())).contains("error_description="+OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID); } @Test - public void testErrorWhenResponseTypeIsUnsupported() throws Exception { + public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception { MockHttpServletRequest request = getValidMockHttpServletRequest(); request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(codeGenerator.generateKey()).thenReturn("sample_code"); - when(authentication.isAuthenticated()).thenReturn(true); + when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); + when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.isAuthenticated()).thenReturn(true); - filter.doFilterInternal(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); - verify(authentication).isAuthenticated(); - verify(registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); - verify(authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(codeGenerator, times(0)).generateKey(); + verify(this.authentication).isAuthenticated(); + verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); + verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); + verify(this.codeGenerator, times(0)).generateKey(); + verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).startsWith(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); assertThat(response.getRedirectedUrl()).contains("error="+OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE); - assertThat(URLDecoder.decode(response.getRedirectedUrl(), StandardCharsets.UTF_8.toString())).contains("error_description="+OAuth2AuthorizationServerMessages.RESPONSE_TYPE_MISSING_OR_INVALID); - } - @Test - public void testSettersAreSettingProperValue() { - OAuth2AuthorizationEndpointFilter blankFilter = new OAuth2AuthorizationEndpointFilter(); - - assertThat(blankFilter.getAuthorizationRedirectStrategy()).isNotEqualTo(authorizationRedirectStrategy); - assertThat(blankFilter.getAuthorizationRequestConverter()).isNotEqualTo(authorizationConverter); - assertThat(blankFilter.getAuthorizationService()).isNull(); - assertThat(blankFilter.getCodeGenerator()).isNotEqualTo(codeGenerator); - assertThat(blankFilter.getRegisteredClientRepository()).isNull(); - - blankFilter.setAuthorizationRequestConverter(authorizationConverter); - blankFilter.setAuthorizationService(authorizationService); - blankFilter.setCodeGenerator(codeGenerator); - blankFilter.setRegisteredClientRepository(registeredClientRepository); - blankFilter.setAuthorizationRedirectStrategy(authorizationRedirectStrategy); - - assertThat(blankFilter.getAuthorizationRedirectStrategy()).isEqualTo(authorizationRedirectStrategy); - assertThat(blankFilter.getAuthorizationRequestConverter()).isEqualTo(authorizationConverter); - assertThat(blankFilter.getAuthorizationService()).isEqualTo(authorizationService); - assertThat(blankFilter.getCodeGenerator()).isEqualTo(codeGenerator); - assertThat(blankFilter.getRegisteredClientRepository()).isEqualTo(registeredClientRepository); - } - - private MockHttpServletRequest getValidMockHttpServletRequest() { MockHttpServletRequest request = new MockHttpServletRequest(); From 7a3e65882ce02dc9440121542c3fb0d8cffdd403 Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Fri, 15 May 2020 13:30:05 -0400 Subject: [PATCH 09/10] Moving response_type checks in requestMatcher - Removed response type checks from validation method and added to RequestMatcher - Changed the validation method to accept OAuth2AuthorizationRequest instead of HttpServletRequest - Changed the response_type tests to expect proceeding with filter chain instead of getting an error --- .../OAuth2AuthorizationEndpointFilter.java | 22 +++++----- ...OAuth2AuthorizationEndpointFilterTest.java | 43 +++++-------------- 2 files changed, 21 insertions(+), 44 deletions(-) diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 313fc7510..08be62965 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -34,7 +34,6 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -96,7 +95,14 @@ public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEn @Override protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { - return !this.authorizationEndpointMatcher.matches(request); + boolean pathMatch = this.authorizationEndpointMatcher.matches(request); + String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); + boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType); + if (pathMatch && responseTypeMatch) { + return false; + }else { + return true; + } } @Override @@ -113,7 +119,7 @@ protected void doFilterInternal(HttpServletRequest request, client = fetchRegisteredClient(request); authorizationRequest = this.authorizationRequestConverter.convert(request); - validateAuthorizationRequest(request, client); + validateAuthorizationRequest(authorizationRequest, client); String code = this.codeGenerator.generateKey(); authorization = buildOAuth2Authorization(client, authorizationRequest, code); @@ -181,14 +187,8 @@ private OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, } - private void validateAuthorizationRequest(HttpServletRequest request, RegisteredClient client) { - String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); - if (StringUtils.isEmpty(responseType) - || !responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) { - throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)); - } - - String redirectUri = request.getParameter(OAuth2ParameterNames.REDIRECT_URI); + private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) { + String redirectUri = authorizationRequest.getRedirectUri(); if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) { throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); } diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java index e826bd010..4bc214647 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -328,24 +328,12 @@ public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedire MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(this.codeGenerator.generateKey()).thenReturn("sample_code"); - when(this.authentication.isAuthenticated()).thenReturn(true); - - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication).isAuthenticated(); - verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); - verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator, times(0)).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).startsWith(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); - assertThat(response.getRedirectedUrl()).contains("error="+OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE); + OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); + spyFilter.doFilter(request, response, filterChain); + verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); + verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); + verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test @@ -355,23 +343,12 @@ public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedir MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); - when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); - when(this.codeGenerator.generateKey()).thenReturn("sample_code"); - when(this.authentication.isAuthenticated()).thenReturn(true); - - - this.filter.doFilter(request, response, filterChain); - - verify(this.authentication).isAuthenticated(); - verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); - verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); - verify(this.codeGenerator, times(0)).generateKey(); - verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); + spyFilter.doFilter(request, response, filterChain); - assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).startsWith(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); - assertThat(response.getRedirectedUrl()).contains("error="+OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE); + verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); + verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); + verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } private MockHttpServletRequest getValidMockHttpServletRequest() { From 2326b7579d1d05649de0132de891535170267840 Mon Sep 17 00:00:00 2001 From: Paurav Munshi Date: Mon, 18 May 2020 13:57:35 -0400 Subject: [PATCH 10/10] Changes to accomodated updated OAuth2Authorization - Updated the api to build OAuth2Authorization in endpoint filter - Changed the api to create the builder for OAuth2Authorization - Changed the api to add attributes to OAuth2Authorization - Changed the key of attribute with which authorization code was added - Set up test user to be returned by mock object --- .../web/OAuth2AuthorizationEndpointFilter.java | 14 ++++++++------ .../web/OAuth2AuthorizationEndpointFilterTest.java | 2 ++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 08be62965..270e92dd0 100644 --- a/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -37,6 +37,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.web.DefaultRedirectStrategy; @@ -116,13 +117,14 @@ protected void doFilterInternal(HttpServletRequest request, try { checkUserAuthenticated(); + Authentication auth = SecurityContextHolder.getContext().getAuthentication(); client = fetchRegisteredClient(request); authorizationRequest = this.authorizationRequestConverter.convert(request); validateAuthorizationRequest(authorizationRequest, client); String code = this.codeGenerator.generateKey(); - authorization = buildOAuth2Authorization(client, authorizationRequest, code); + authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code); this.authorizationService.save(authorization); String redirectUri = getRedirectUri(authorizationRequest, client); @@ -175,12 +177,12 @@ private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throw } - private OAuth2Authorization buildOAuth2Authorization(RegisteredClient client, + private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client, OAuth2AuthorizationRequest authorizationRequest, String code) { - OAuth2Authorization authorization = OAuth2Authorization.createBuilder() - .clientId(authorizationRequest.getClientId()) - .addAttribute(OAuth2ParameterNames.CODE, code) - .attribures(authorizationRequest.getAttributes()) + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client) + .principalName(auth.getPrincipal().toString()) + .attribute(TokenType.AUTHORIZATION_CODE.getValue(), code) + .attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes())) .build(); return authorization; diff --git a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java index 4bc214647..4b2e86a1e 100644 --- a/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java +++ b/core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java @@ -119,6 +119,7 @@ public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectUR RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.getPrincipal()).thenReturn("test-user"); when(this.authentication.isAuthenticated()).thenReturn(true); @@ -145,6 +146,7 @@ public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRe RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); when(this.codeGenerator.generateKey()).thenReturn("sample_code"); + when(this.authentication.getPrincipal()).thenReturn("test-user"); when(this.authentication.isAuthenticated()).thenReturn(true); this.filter.doFilter(request, response, filterChain);