From 63056995f38a8b5fb7052978c90459cc260b77e0 Mon Sep 17 00:00:00 2001 From: cbilodeau Date: Tue, 27 Jun 2023 11:33:34 -0400 Subject: [PATCH] Fix generating refresh_token with null 'sid' --- .../authorization/token/JwtGenerator.java | 4 +- ...freshTokenAuthenticationProviderTests.java | 82 ++++++++++++++++++- 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java index 2bbff980a..b02b2ca34 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java @@ -134,7 +134,9 @@ public Jwt generate(OAuth2TokenContext context) { } } else if (AuthorizationGrantType.REFRESH_TOKEN.equals(context.getAuthorizationGrantType())) { OidcIdToken currentIdToken = context.getAuthorization().getToken(OidcIdToken.class).getToken(); - claimsBuilder.claim("sid", currentIdToken.getClaim("sid")); + if (currentIdToken.hasClaim("sid")) { + claimsBuilder.claim("sid", currentIdToken.getClaim("sid")); + } claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, currentIdToken.getClaim(IdTokenClaimNames.AUTH_TIME)); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java index 61f0452da..3f9981e7e 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java @@ -47,6 +47,7 @@ import org.springframework.security.oauth2.jwt.JoseHeaderNames; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.jwt.JwtEncoderParameters; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; @@ -250,7 +251,86 @@ public void authenticateWhenValidRefreshTokenThenReturnIdToken() { assertThat(idTokenContext.getJwsHeader()).isNotNull(); assertThat(idTokenContext.getClaims()).isNotNull(); - verify(this.jwtEncoder, times(2)).encode(any()); // Access token and ID Token + ArgumentCaptor jwtEncoderParametersArgumentCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class); + verify(this.jwtEncoder, times(2)).encode(jwtEncoderParametersArgumentCaptor.capture()); // Access token and ID Token + JwtEncoderParameters jwtEncoderParameters = jwtEncoderParametersArgumentCaptor.getValue(); + assertThat(jwtEncoderParameters.getClaims().getClaims().get("sid")).isNotNull(); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); + assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken()); + assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken()); + OAuth2Authorization.Token idToken = updatedAuthorization.getToken(OidcIdToken.class); + assertThat(idToken).isNotNull(); + assertThat(accessTokenAuthentication.getAdditionalParameters()) + .containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue())); + assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken()); + // By default, refresh token is reused + assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken()); + } + + @Test + public void authenticateWhenValidRefreshTokenThenReturnIdTokenWithoutSid() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); + OidcIdToken authorizedIdToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject("subject") + .issuedAt(Instant.now()) + .expiresAt(Instant.now().plusSeconds(60)) + .claim(IdTokenClaimNames.AUTH_TIME, Date.from(Instant.now())) + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).token(authorizedIdToken).build(); + when(this.authorizationService.findByToken( + eq(authorization.getRefreshToken().getToken().getTokenValue()), + eq(OAuth2TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); + verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture()); + // Access Token context + JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0); + assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(accessTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); + assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(accessTokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes()); + assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN); + assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); + assertThat(accessTokenContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(accessTokenContext.getJwsHeader()).isNotNull(); + assertThat(accessTokenContext.getClaims()).isNotNull(); + Map claims = new HashMap<>(); + accessTokenContext.getClaims().claims(claims::putAll); + assertThat(claims).flatExtracting(OAuth2ParameterNames.SCOPE) + .containsExactlyInAnyOrder(OidcScopes.OPENID, "scope1"); + // ID Token context + JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1); + assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(idTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); + assertThat(idTokenContext.getAuthorization()).isNotEqualTo(authorization); + assertThat(idTokenContext.getAuthorization().getAccessToken()).isNotEqualTo(authorization.getAccessToken()); + assertThat(idTokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes()); + assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN); + assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); + assertThat(idTokenContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(idTokenContext.getJwsHeader()).isNotNull(); + assertThat(idTokenContext.getClaims()).isNotNull(); + + ArgumentCaptor jwtEncoderParametersArgumentCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class); + verify(this.jwtEncoder, times(2)).encode(jwtEncoderParametersArgumentCaptor.capture()); // Access token and ID Token + JwtEncoderParameters jwtEncoderParameters = jwtEncoderParametersArgumentCaptor.getValue(); + assertThat(jwtEncoderParameters.getClaims().getClaims().get("sid")).isNull(); ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(this.authorizationService).save(authorizationCaptor.capture());