diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java index 28ba4d3d5..0cc68d44f 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java @@ -89,11 +89,14 @@ public Jwt generate(OAuth2TokenContext context) { Instant issuedAt = Instant.now(); Instant expiresAt; + JwsHeader.Builder headersBuilder; if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) { // TODO Allow configuration for ID Token time-to-live expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); + headersBuilder = JwsHeader.with(registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm()); } else { expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive()); + headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256); } // @formatter:off @@ -125,8 +128,6 @@ public Jwt generate(OAuth2TokenContext context) { } // @formatter:on - JwsHeader.Builder headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256); - if (this.jwtCustomizer != null) { // @formatter:off JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(headersBuilder, claimsBuilder) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java index 727bfa4c6..712aa3b04 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java @@ -200,9 +200,6 @@ private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) { ArgumentCaptor jwtEncoderParametersCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class); verify(this.jwtEncoder).encode(jwtEncoderParametersCaptor.capture()); - JwsHeader jwsHeader = jwtEncoderParametersCaptor.getValue().getJwsHeader(); - assertThat(jwsHeader.getAlgorithm()).isEqualTo(SignatureAlgorithm.RS256); - JwtClaimsSet jwtClaimsSet = jwtEncoderParametersCaptor.getValue().getClaims(); assertThat(jwtClaimsSet.getIssuer().toExternalForm()).isEqualTo(tokenContext.getProviderContext().getIssuer()); assertThat(jwtClaimsSet.getSubject()).isEqualTo(tokenContext.getAuthorization().getPrincipalName()); @@ -210,14 +207,20 @@ private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) { Instant issuedAt = Instant.now(); Instant expiresAt; + JwsHeader.Builder headersBuilder; if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) { expiresAt = issuedAt.plus(tokenContext.getRegisteredClient().getTokenSettings().getAccessTokenTimeToLive()); + headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256); } else { expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); + headersBuilder = JwsHeader.with(tokenContext.getRegisteredClient().getTokenSettings().getIdTokenSignatureAlgorithm()); } assertThat(jwtClaimsSet.getIssuedAt()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1)); assertThat(jwtClaimsSet.getExpiresAt()).isBetween(expiresAt.minusSeconds(1), expiresAt.plusSeconds(1)); + JwsHeader jwsHeader = jwtEncoderParametersCaptor.getValue().getJwsHeader(); + assertThat(jwsHeader.getAlgorithm()).isEqualTo(headersBuilder.build().getAlgorithm()); + if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) { assertThat(jwtClaimsSet.getNotBefore()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));