Skip to content

Add Refresh Token grant support #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.web.JwkSetEndpointFilter;
Expand Down Expand Up @@ -149,6 +150,10 @@ public void init(B builder) {
jwtEncoder);
builder.authenticationProvider(postProcess(clientCredentialsAuthenticationProvider));

OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider =
new OAuth2RefreshTokenAuthenticationProvider(getAuthorizationService(builder), jwtEncoder);
builder.authenticationProvider(postProcess(refreshTokenAuthenticationProvider));

OAuth2TokenRevocationAuthenticationProvider tokenRevocationAuthenticationProvider =
new OAuth2TokenRevocationAuthenticationProvider(
getAuthorizationService(builder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,18 @@ private boolean hasToken(OAuth2Authorization authorization, String token, TokenT
} else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
return authorizationCode != null && authorizationCode.getTokenValue().equals(token);
} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
}

if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
return authorization.getTokens().getAccessToken() != null &&
authorization.getTokens().getAccessToken().getTokenValue().equals(token);
}

if (TokenType.REFRESH_TOKEN.equals(tokenType)) {
return authorization.getTokens().getRefreshToken() != null &&
authorization.getTokens().getRefreshToken().getTokenValue().equals(token);
}

return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
*/
package org.springframework.security.oauth2.server.authorization.authentication;

import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;

Expand All @@ -41,6 +43,7 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication
private final RegisteredClient registeredClient;
private final Authentication clientPrincipal;
private final OAuth2AccessToken accessToken;
private final OAuth2RefreshToken refreshToken;

/**
* Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters.
Expand All @@ -51,13 +54,27 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication
*/
public OAuth2AccessTokenAuthenticationToken(RegisteredClient registeredClient,
Authentication clientPrincipal, OAuth2AccessToken accessToken) {
this(registeredClient, clientPrincipal, accessToken, null);
}

/**
* Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters.
*
* @param registeredClient the registered client
* @param clientPrincipal the authenticated client principal
* @param accessToken the access token
* @param refreshToken the refresh token
*/
public OAuth2AccessTokenAuthenticationToken(RegisteredClient registeredClient,
Authentication clientPrincipal, OAuth2AccessToken accessToken, @Nullable OAuth2RefreshToken refreshToken) {
super(Collections.emptyList());
Assert.notNull(registeredClient, "registeredClient cannot be null");
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
Assert.notNull(accessToken, "accessToken cannot be null");
this.registeredClient = registeredClient;
this.clientPrincipal = clientPrincipal;
this.accessToken = accessToken;
this.refreshToken = refreshToken;
}

@Override
Expand Down Expand Up @@ -87,4 +104,15 @@ public RegisteredClient getRegisteredClient() {
public OAuth2AccessToken getAccessToken() {
return this.accessToken;
}


/**
* Returns the {@link OAuth2RefreshToken} if provided
*
* @return the {@link OAuth2RefreshToken}
*/
@Nullable
public OAuth2RefreshToken getRefreshToken() {
return refreshToken;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,9 @@
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.JoseHeader;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
Expand All @@ -41,12 +38,6 @@
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.Set;

/**
Expand Down Expand Up @@ -128,37 +119,24 @@ public Authentication authenticate(Authentication authentication) throws Authent
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();

// TODO Allow configuration for issuer claim
URL issuer = null;
try {
issuer = URI.create("https://oauth2.provider.com").toURL();
} catch (MalformedURLException e) { }

Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token time-to-live
Set<String> authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES);
Jwt jwt = OAuth2TokenIssuerUtil
.issueJwtAccessToken(this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), authorizedScopes);
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), authorizedScopes);

JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims()
.issuer(issuer)
.subject(authorization.getPrincipalName())
.audience(Collections.singletonList(registeredClient.getClientId()))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.notBefore(issuedAt)
.claim(OAuth2ParameterNames.SCOPE, authorizedScopes)
.build();

Jwt jwt = this.jwtEncoder.encode(joseHeader, jwtClaimsSet);
OAuth2Tokens.Builder tokensBuilder = OAuth2Tokens.from(authorization.getTokens())
.accessToken(accessToken);

OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
OAuth2RefreshToken refreshToken = null;
if (registeredClient.getTokenSettings().enableRefreshTokens()) {
refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(registeredClient.getTokenSettings().refreshTokenTimeToLive());
tokensBuilder.refreshToken(refreshToken);
}

OAuth2Tokens tokens = tokensBuilder.build();
authorization = OAuth2Authorization.from(authorization)
.tokens(OAuth2Tokens.from(authorization.getTokens())
.accessToken(accessToken)
.build())
.tokens(tokens)
.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
.build();

Expand All @@ -167,7 +145,7 @@ public Authentication authenticate(Authentication authentication) throws Authent

this.authorizationService.save(authorization);

return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken, refreshToken);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.JoseHeader;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
Expand All @@ -36,12 +32,6 @@
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -101,29 +91,8 @@ public Authentication authenticate(Authentication authentication) throws Authent
scopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes());
}

JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();

// TODO Allow configuration for issuer claim
URL issuer = null;
try {
issuer = URI.create("https://oauth2.provider.com").toURL();
} catch (MalformedURLException e) { }

Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); // TODO Allow configuration for access token time-to-live

JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims()
.issuer(issuer)
.subject(clientPrincipal.getName())
.audience(Collections.singletonList(registeredClient.getClientId()))
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.notBefore(issuedAt)
.claim(OAuth2ParameterNames.SCOPE, scopes)
.build();

Jwt jwt = this.jwtEncoder.encode(joseHeader, jwtClaimsSet);

Jwt jwt = OAuth2TokenIssuerUtil
.issueJwtAccessToken(this.jwtEncoder, clientPrincipal.getName(), registeredClient.getClientId(), scopes);
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.server.authorization.authentication;

import java.time.Instant;
import java.util.Set;

import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
import org.springframework.util.Assert;

/**
* An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant.
*
* @author Alexey Nesterov
* @since 0.0.3
* @see OAuth2RefreshTokenAuthenticationToken
* @see OAuth2AccessTokenAuthenticationToken
* @see OAuth2AuthorizationService
* @see JwtEncoder
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.5">Section 1.5 Refresh Token</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
*/

public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {

private final OAuth2AuthorizationService authorizationService;
private final JwtEncoder jwtEncoder;

public OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) {
Assert.notNull(authorizationService, "authorizationService cannot be null");
Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");

this.authorizationService = authorizationService;
this.jwtEncoder = jwtEncoder;
}

@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2RefreshTokenAuthenticationToken refreshTokenAuthentication =
(OAuth2RefreshTokenAuthenticationToken) authentication;

OAuth2ClientAuthenticationToken clientPrincipal = null;
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(refreshTokenAuthentication.getPrincipal().getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) refreshTokenAuthentication.getPrincipal();
}

if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
}

OAuth2Authorization authorization = this.authorizationService.findByToken(refreshTokenAuthentication.getRefreshToken(), TokenType.REFRESH_TOKEN);
if (authorization == null) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

Instant refreshTokenExpiration = authorization.getTokens().getRefreshToken().getExpiresAt();
if (refreshTokenExpiration.isBefore(Instant.now())) {
// as per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization
// code, resource owner credentials) or refresh token is invalid, expired, revoked [...].
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
}

RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();

// https://tools.ietf.org/html/rfc6749#section-6
// The requested scope MUST NOT include any scope not originally granted by the resource owner,
// and if omitted is treated as equal to the scope originally granted by the resource owner.
Set<String> refreshTokenScopes = refreshTokenAuthentication.getScopes();
Set<String> authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES);
if (!authorizedScopes.containsAll(refreshTokenScopes)) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_SCOPE));
}

if (refreshTokenScopes.isEmpty()) {
refreshTokenScopes = authorizedScopes;
}

Jwt jwt = OAuth2TokenIssuerUtil
.issueJwtAccessToken(this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), refreshTokenScopes);
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), refreshTokenScopes);

TokenSettings tokenSettings = registeredClient.getTokenSettings();
OAuth2RefreshToken refreshToken;
if (tokenSettings.reuseRefreshTokens()) {
refreshToken = authorization.getTokens().getRefreshToken();
} else {
refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(tokenSettings.refreshTokenTimeToLive());
}

authorization = OAuth2Authorization.from(authorization)
.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
.tokens(OAuth2Tokens.builder().accessToken(accessToken).refreshToken(refreshToken).build())
.build();

this.authorizationService.save(authorization);

return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken, refreshToken);
}

@Override
public boolean supports(Class<?> authentication) {
return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
}
}
Loading