Skip to content

Allow configurable refresh token strategy for authorization_code grant #1432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
Expand Down Expand Up @@ -213,24 +212,23 @@ public Authentication authenticate(Authentication authentication) throws Authent

// ----- Refresh token -----
OAuth2RefreshToken refreshToken = null;
if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN) &&
// Do not issue refresh token to public client
!clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {

if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the refresh token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
if (generatedRefreshToken != null) {
if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate a valid refresh token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}

if (this.logger.isTraceEnabled()) {
this.logger.trace("Generated refresh token");
}
if (this.logger.isTraceEnabled()) {
this.logger.trace("Generated refresh token");
}

refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
authorizationBuilder.refreshToken(refreshToken);
refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
authorizationBuilder.refreshToken(refreshToken);
}
}

// ----- ID token -----
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,8 +21,11 @@
import org.springframework.lang.Nullable;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;

/**
* An {@link OAuth2TokenGenerator} that generates an {@link OAuth2RefreshToken}.
Expand All @@ -42,9 +45,22 @@ public OAuth2RefreshToken generate(OAuth2TokenContext context) {
if (!OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
return null;
}
if (isPublicClientForAuthorizationCodeGrant(context)) {
// Do not issue refresh token to public client
return null;
}

Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(context.getRegisteredClient().getTokenSettings().getRefreshTokenTimeToLive());
return new OAuth2RefreshToken(this.refreshTokenGenerator.generateKey(), issuedAt, expiresAt);
}

private static boolean isPublicClientForAuthorizationCodeGrant(OAuth2TokenContext context) {
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) &&
(context.getAuthorizationGrant().getPrincipal() instanceof OAuth2ClientAuthenticationToken clientPrincipal)) {
return clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE);
}
return false;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2Token;
Expand Down Expand Up @@ -371,7 +372,7 @@ public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2Authentication
}

@Test
public void authenticateWhenRefreshTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
public void authenticateWhenInvalidRefreshTokenGeneratedThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
Expand All @@ -389,7 +390,8 @@ public void authenticateWhenRefreshTokenNotGeneratedThenThrowOAuth2Authenticatio
doAnswer(answer -> {
OAuth2TokenContext context = answer.getArgument(0);
if (OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
return null;
return new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
} else {
return answer.callRealMethod();
}
Expand All @@ -400,7 +402,7 @@ public void authenticateWhenRefreshTokenNotGeneratedThenThrowOAuth2Authenticatio
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
assertThat(error.getDescription()).contains("The token generator failed to generate the refresh token.");
assertThat(error.getDescription()).contains("The token generator failed to generate a valid refresh token.");
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.lang.Nullable;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationProvider;
Expand All @@ -65,9 +66,12 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.crypto.password.NoOpPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
Expand Down Expand Up @@ -426,6 +430,54 @@ public void requestWhenPublicClientWithPkceThenReturnAccessTokenResponse() throw
assertThat(authorizationCodeToken.getMetadata().get(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME)).isEqualTo(true);
}

// gh-1430
@Test
public void requestWhenPublicClientWithPkceAndCustomRefreshTokenGeneratorThenReturnRefreshToken() throws Exception {
this.spring.register(AuthorizationServerConfigurationWithCustomRefreshTokenGenerator.class).autowire();

RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
.build();
this.registeredClientRepository.save(registeredClient);

MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient))
.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED);

String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
assertThat(authorizationCodeAuthorization).isNotNull();
assertThat(authorizationCodeAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);

this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER))
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
.andExpect(status().isOk())
.andExpect(jsonPath("$.access_token").isNotEmpty())
.andExpect(jsonPath("$.token_type").isNotEmpty())
.andExpect(jsonPath("$.expires_in").isNotEmpty())
.andExpect(jsonPath("$.refresh_token").isNotEmpty())
.andExpect(jsonPath("$.scope").isNotEmpty());

OAuth2Authorization authorization = this.authorizationService.findById(authorizationCodeAuthorization.getId());
assertThat(authorization).isNotNull();
assertThat(authorization.getAccessToken()).isNotNull();
assertThat(authorization.getRefreshToken()).isNotNull();

OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCodeToken = authorization.getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCodeToken).isNotNull();
assertThat(authorizationCodeToken.getMetadata().get(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME)).isEqualTo(true);
}

@Test
public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenBadRequest() throws Exception {
this.spring.register(AuthorizationServerConfiguration.class).autowire();
Expand Down Expand Up @@ -896,6 +948,42 @@ static class ParametersMapper extends JdbcOAuth2AuthorizationService.OAuth2Autho

}

@EnableWebSecurity
@Import(OAuth2AuthorizationServerConfiguration.class)
static class AuthorizationServerConfigurationWithCustomRefreshTokenGenerator extends AuthorizationServerConfiguration {

@Bean
JwtEncoder jwtEncoder() {
return jwtEncoder;
}

@Bean
OAuth2TokenGenerator<?> tokenGenerator() {
JwtGenerator jwtGenerator = new JwtGenerator(jwtEncoder());
jwtGenerator.setJwtCustomizer(jwtCustomizer());
OAuth2TokenGenerator<OAuth2RefreshToken> refreshTokenGenerator = new CustomRefreshTokenGenerator();
return new DelegatingOAuth2TokenGenerator(jwtGenerator, refreshTokenGenerator);
}

private static final class CustomRefreshTokenGenerator implements OAuth2TokenGenerator<OAuth2RefreshToken> {
private final StringKeyGenerator refreshTokenGenerator =
new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);

@Nullable
@Override
public OAuth2RefreshToken generate(OAuth2TokenContext context) {
if (!OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
return null;
}
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(context.getRegisteredClient().getTokenSettings().getRefreshTokenTimeToLive());
return new OAuth2RefreshToken(this.refreshTokenGenerator.generateKey(), issuedAt, expiresAt);
}

}

}

@EnableWebSecurity
@Configuration(proxyBeanMethods = false)
static class AuthorizationServerConfigurationWithSecurityContextRepository extends AuthorizationServerConfiguration {
Expand Down
Loading