Skip to content

Commit 5a94d18

Browse files
committed
Implemented code+tests for both imperative & reactive codelines for issue spring-projects#6609
1 parent d86550f commit 5a94d18

File tree

5 files changed

+199
-27
lines changed

5 files changed

+199
-27
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java

+42-5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
3333
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
3434
import org.springframework.security.oauth2.core.AuthorizationGrantType;
35+
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
3536
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
3637
import org.springframework.util.Assert;
3738
import org.springframework.util.StringUtils;
@@ -42,6 +43,9 @@
4243

4344
import javax.servlet.http.HttpServletRequest;
4445
import javax.servlet.http.HttpServletResponse;
46+
import java.time.Clock;
47+
import java.time.Duration;
48+
import java.time.Instant;
4549

4650
/**
4751
* An implementation of a {@link HandlerMethodArgumentResolver} that is capable
@@ -69,6 +73,10 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
6973
private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
7074
new DefaultClientCredentialsTokenResponseClient();
7175

76+
private Clock clock = Clock.systemUTC();
77+
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
78+
79+
7280
/**
7381
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
7482
*
@@ -105,18 +113,29 @@ public Object resolveArgument(MethodParameter parameter,
105113
"@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
106114
}
107115

116+
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
117+
if (clientRegistration == null) {
118+
return null;
119+
}
120+
108121
Authentication principal = SecurityContextHolder.getContext().getAuthentication();
109122
HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
110123

111124
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
112125
clientRegistrationId, principal, servletRequest);
113126
if (authorizedClient != null) {
114-
return authorizedClient;
115-
}
127+
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
128+
// MH TODO: Refresh token
129+
}
116130

117-
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
118-
if (clientRegistration == null) {
119-
return null;
131+
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
132+
if (hasTokenExpired(authorizedClient)) {
133+
HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class);
134+
authorizedClient = this.authorizeClientCredentialsClient(clientRegistration, servletRequest, servletResponse);
135+
}
136+
}
137+
138+
return authorizedClient;
120139
}
121140

122141
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
@@ -172,6 +191,24 @@ private OAuth2AuthorizedClient authorizeClientCredentialsClient(ClientRegistrati
172191
return authorizedClient;
173192
}
174193

194+
private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) {
195+
if (this.authorizedClientRepository == null) {
196+
return false;
197+
}
198+
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
199+
if (refreshToken == null) {
200+
return false;
201+
}
202+
return hasTokenExpired(authorizedClient);
203+
}
204+
205+
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
206+
Instant now = this.clock.instant();
207+
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
208+
209+
return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
210+
}
211+
175212
/**
176213
* Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant.
177214
*

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java

+14-10
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ public OAuth2AuthorizedClientResolver(
7070
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
7171
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
7272
* resolved from the current Authentication.
73+
*
7374
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
7475
* Default is false.
7576
*/
@@ -80,6 +81,7 @@ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClie
8081
/**
8182
* If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is
8283
* recommended to be cautious with this feature since all HTTP requests will receive the access token.
84+
*
8385
* @param clientRegistrationId the id to use
8486
*/
8587
public void setDefaultClientRegistrationId(String clientRegistrationId) {
@@ -89,6 +91,7 @@ public void setDefaultClientRegistrationId(String clientRegistrationId) {
8991
/**
9092
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
9193
* client_credentials grant.
94+
*
9295
* @param clientCredentialsTokenResponseClient the client to use
9396
*/
9497
public void setClientCredentialsTokenResponseClient(
@@ -98,7 +101,7 @@ public void setClientCredentialsTokenResponseClient(
98101
}
99102

100103
Mono<Request> createDefaultedRequest(String clientRegistrationId,
101-
Authentication authentication, ServerWebExchange exchange) {
104+
Authentication authentication, ServerWebExchange exchange) {
102105
Mono<Authentication> defaultedAuthentication = Mono.justOrEmpty(authentication)
103106
.switchIfEmpty(currentAuthentication());
104107

@@ -124,14 +127,14 @@ Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {
124127

125128
private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
126129
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
127-
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
128-
.flatMap(clientRegistration -> {
129-
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
130-
return clientCredentials(clientRegistration, authentication, exchange);
131-
}
132-
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
133-
});
134-
}
130+
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
131+
.flatMap(clientRegistration -> {
132+
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
133+
return clientCredentials(clientRegistration, authentication, exchange);
134+
}
135+
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
136+
});
137+
}
135138

136139
Mono<OAuth2AuthorizedClient> clientCredentials(
137140
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
@@ -149,6 +152,7 @@ private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistratio
149152

150153
/**
151154
* Attempts to load the client registration id from the current {@link Authentication}
155+
*
152156
* @return
153157
*/
154158
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
@@ -176,7 +180,7 @@ static class Request {
176180
private final ServerWebExchange exchange;
177181

178182
public Request(String clientRegistrationId, Authentication authentication,
179-
ServerWebExchange exchange) {
183+
ServerWebExchange exchange) {
180184
this.clientRegistrationId = clientRegistrationId;
181185
this.authentication = authentication;
182186
this.exchange = exchange;

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientResolver.java

+53-11
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@
3131
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
3232
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
3333
import org.springframework.security.oauth2.core.AuthorizationGrantType;
34+
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
3435
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
3536
import org.springframework.util.Assert;
3637
import org.springframework.web.server.ServerWebExchange;
3738
import reactor.core.publisher.Mono;
3839

40+
import java.time.Clock;
41+
import java.time.Duration;
42+
import java.time.Instant;
3943
import java.util.Optional;
4044

4145
/**
@@ -57,6 +61,10 @@ class OAuth2AuthorizedClientResolver {
5761

5862
private String defaultClientRegistrationId;
5963

64+
private Clock clock = Clock.systemUTC();
65+
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
66+
67+
6068
public OAuth2AuthorizedClientResolver(
6169
ReactiveClientRegistrationRepository clientRegistrationRepository,
6270
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
@@ -70,6 +78,7 @@ public OAuth2AuthorizedClientResolver(
7078
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
7179
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
7280
* resolved from the current Authentication.
81+
*
7382
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
7483
* Default is false.
7584
*/
@@ -80,6 +89,7 @@ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClie
8089
/**
8190
* If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is
8291
* recommended to be cautious with this feature since all HTTP requests will receive the access token.
92+
*
8393
* @param clientRegistrationId the id to use
8494
*/
8595
public void setDefaultClientRegistrationId(String clientRegistrationId) {
@@ -89,6 +99,7 @@ public void setDefaultClientRegistrationId(String clientRegistrationId) {
8999
/**
90100
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
91101
* client_credentials grant.
102+
*
92103
* @param clientCredentialsTokenResponseClient the client to use
93104
*/
94105
public void setClientCredentialsTokenResponseClient(
@@ -98,7 +109,7 @@ public void setClientCredentialsTokenResponseClient(
98109
}
99110

100111
Mono<Request> createDefaultedRequest(String clientRegistrationId,
101-
Authentication authentication, ServerWebExchange exchange) {
112+
Authentication authentication, ServerWebExchange exchange) {
102113
Mono<Authentication> defaultedAuthentication = Mono.justOrEmpty(authentication)
103114
.switchIfEmpty(currentAuthentication());
104115

@@ -120,19 +131,27 @@ Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {
120131
Authentication authentication = request.getAuthentication();
121132
ServerWebExchange exchange = request.getExchange();
122133
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange)
123-
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange));
134+
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange))
135+
.flatMap(client -> {
136+
if (hasTokenExpired(client)) {
137+
return authorizedClientNotLoaded(clientRegistrationId, authentication, exchange);
138+
} else {
139+
return Mono.just(client);
140+
}
141+
});
142+
124143
}
125144

126145
private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
127146
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
128-
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
129-
.flatMap(clientRegistration -> {
130-
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
131-
return clientCredentials(clientRegistration, authentication, exchange);
132-
}
133-
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
134-
});
135-
}
147+
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
148+
.flatMap(clientRegistration -> {
149+
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
150+
return clientCredentials(clientRegistration, authentication, exchange);
151+
}
152+
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
153+
});
154+
}
136155

137156
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
138157
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
@@ -148,8 +167,31 @@ private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistratio
148167
.thenReturn(authorizedClient);
149168
}
150169

170+
private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) {
171+
if (this.authorizedClientRepository == null) {
172+
return false;
173+
}
174+
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
175+
if (refreshToken == null) {
176+
return false;
177+
}
178+
return hasTokenExpired(authorizedClient);
179+
}
180+
181+
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
182+
Instant now = this.clock.instant();
183+
if (authorizedClient.getAccessToken() == null) {
184+
return false; // Test scenario: authorizedClient has no accessToken
185+
} else {
186+
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
187+
188+
return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
189+
}
190+
}
191+
151192
/**
152193
* Attempts to load the client registration id from the current {@link Authentication}
194+
*
153195
* @return
154196
*/
155197
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
@@ -177,7 +219,7 @@ static class Request {
177219
private final ServerWebExchange exchange;
178220

179221
public Request(String clientRegistrationId, Authentication authentication,
180-
ServerWebExchange exchange) {
222+
ServerWebExchange exchange) {
181223
this.clientRegistrationId = clientRegistrationId;
182224
this.authentication = authentication;
183225
this.exchange = exchange;

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java

+32-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import org.junit.After;
1919
import org.junit.Before;
2020
import org.junit.Test;
21+
import org.mockito.invocation.InvocationOnMock;
22+
import org.mockito.stubbing.Answer;
2123
import org.springframework.core.MethodParameter;
2224
import org.springframework.mock.web.MockHttpServletRequest;
2325
import org.springframework.security.authentication.TestingAuthenticationToken;
@@ -42,7 +44,9 @@
4244
import org.springframework.web.context.request.ServletWebRequest;
4345

4446
import javax.servlet.http.HttpServletRequest;
47+
import javax.servlet.http.HttpServletResponse;
4548
import java.lang.reflect.Method;
49+
import java.time.Instant;
4650

4751
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
4852
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
@@ -104,7 +108,8 @@ public void setup() {
104108
when(this.authorizedClientRepository.loadAuthorizedClient(
105109
eq(this.registration1.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class)))
106110
.thenReturn(this.authorizedClient1);
107-
this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class));
111+
this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class, withSettings()
112+
.name("expiresAt").defaultAnswer((Answer<Instant>) invocation -> Instant.now())));
108113
when(this.authorizedClientRepository.loadAuthorizedClient(
109114
eq(this.registration2.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class)))
110115
.thenReturn(this.authorizedClient2);
@@ -230,6 +235,32 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien
230235
eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null));
231236
}
232237

238+
@Test
239+
public void resolveArgumentClientCredentialsExpireReacquireToken() throws Exception {
240+
OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
241+
mock(OAuth2AccessTokenResponseClient.class);
242+
this.argumentResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);
243+
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
244+
.withToken("access-token-1234")
245+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
246+
.expiresIn(3600)
247+
.build();
248+
when(clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
249+
250+
MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", OAuth2AuthorizedClient.class);
251+
252+
OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument(
253+
methodParameter, null, new ServletWebRequest(this.request), null);
254+
255+
assertThat(authorizedClient).isNotNull();
256+
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration2);
257+
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName);
258+
assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken());
259+
260+
verify(this.authorizedClientRepository).saveAuthorizedClient(
261+
eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null));
262+
}
263+
233264
private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
234265
Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes);
235266
return new MethodParameter(method, 0);

0 commit comments

Comments
 (0)