diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java index fab21e865..a37a9d55b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java @@ -74,6 +74,25 @@ public OAuth2AuthorizationCodeAuthenticationToken(String code, this.redirectUri = redirectUri; } + /** + * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters. + * + * @param code the authorization code + * @param clientId the client identifier + * @param clientPrincipal the authenticated client principal + * @param redirectUri the redirect uri + */ + public OAuth2AuthorizationCodeAuthenticationToken(String code, + String clientId, Authentication clientPrincipal, @Nullable String redirectUri) { + super(Collections.emptyList()); + Assert.hasText(code, "code cannot be empty"); + Assert.notNull(clientPrincipal, "clientPrincipal cannot be null"); + this.code = code; + this.clientId = clientId; + this.clientPrincipal = clientPrincipal; + this.redirectUri = redirectUri; + } + @Override public Object getPrincipal() { return this.clientPrincipal != null ? this.clientPrincipal : this.clientId; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 3680d6634..6fafa853a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -197,16 +197,14 @@ public Authentication convert(HttpServletRequest request) { MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // client_id (REQUIRED) + // client_id (REQUIRED, if the client is not authenticating with the authorization server) String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); - Authentication clientPrincipal = null; if (StringUtils.hasText(clientId)) { if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); } - } else { - clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); } + Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); // code (REQUIRED) String code = parameters.getFirst(OAuth2ParameterNames.CODE); @@ -223,9 +221,7 @@ public Authentication convert(HttpServletRequest request) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI); } - return clientPrincipal != null ? - new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) : - new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri); + return new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, clientPrincipal, redirectUri); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java index 0ecd1472e..7ebcd7726 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java @@ -257,6 +257,57 @@ public void doFilterWhenAuthorizationCodeTokenRequestValidThenAccessTokenRespons assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); } + @Test + public void doFilterWhenAuthorizationCodeAndClientIdTokenRequestValidThenAccessTokenResponse() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plus(Duration.ofHours(1)), + new HashSet<>(Arrays.asList("scope1", "scope2"))); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + new OAuth2AccessTokenAuthenticationToken( + registeredClient, clientPrincipal, accessToken); + + when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + + MockHttpServletRequest request = createAuthorizationCodeAndClientIdTokenRequest(registeredClient); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + ArgumentCaptor authorizationCodeAuthenticationCaptor = + ArgumentCaptor.forClass(OAuth2AuthorizationCodeAuthenticationToken.class); + verify(this.authenticationManager).authenticate(authorizationCodeAuthenticationCaptor.capture()); + + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = + authorizationCodeAuthenticationCaptor.getValue(); + assertThat(authorizationCodeAuthentication.getCode()).isEqualTo( + request.getParameter(OAuth2ParameterNames.CODE)); + assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo( + request.getParameter(OAuth2ParameterNames.REDIRECT_URI)); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response); + + OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken(); + assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType()); + assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue()); + assertThat(accessTokenResult.getIssuedAt()).isBetween( + accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1)); + assertThat(accessTokenResult.getExpiresAt()).isBetween( + accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1)); + assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes()); + } + @Test public void doFilterWhenTokenRequestMultipleScopeThenInvalidRequestError() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); @@ -363,6 +414,21 @@ private static MockHttpServletRequest createAuthorizationCodeTokenRequest(Regist return request; } + private static MockHttpServletRequest createAuthorizationCodeAndClientIdTokenRequest(RegisteredClient registeredClient) { + String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]); + + String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + + request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + + return request; + } + private static MockHttpServletRequest createClientCredentialsTokenRequest(RegisteredClient registeredClient) { String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);