Skip to content

Commit 1c9ab91

Browse files
Warren Baileyjgrandja
Warren Bailey
authored andcommitted
When expired retrieve new Client Credentials token.
Once client credentials access token has expired retrieve a new token from the OAuth2 authorization server. These tokens can't be refreshed because they do not have a refresh token associated with. This is standard behaviour for Oauth 2 client credentails Fixes gh-5893
1 parent 9b65107 commit 1c9ab91

File tree

5 files changed

+204
-5
lines changed

5 files changed

+204
-5
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegi
133133
});
134134
}
135135

136-
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
136+
Mono<OAuth2AuthorizedClient> clientCredentials(
137137
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
138138
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
139139
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

+27-2
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
8484
private final OAuth2AuthorizedClientResolver authorizedClientResolver;
8585

8686
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
87+
this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository));
88+
}
89+
90+
ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) {
8791
this.authorizedClientRepository = authorizedClientRepository;
88-
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
92+
this.authorizedClientResolver = authorizedClientResolver;
8993
}
9094

9195
/**
@@ -245,13 +249,30 @@ private Mono<OAuth2AuthorizedClientResolver.Request> createRequest(ClientRequest
245249
}
246250

247251
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
248-
if (shouldRefresh(authorizedClient)) {
252+
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
253+
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
254+
return createRequest(request)
255+
.flatMap(r -> authorizeWithClientCredentials(clientRegistration, r));
256+
} else if (shouldRefresh(authorizedClient)) {
249257
return createRequest(request)
250258
.flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
251259
}
252260
return Mono.just(authorizedClient);
253261
}
254262

263+
private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
264+
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
265+
}
266+
267+
private Mono<OAuth2AuthorizedClient> authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) {
268+
Authentication authentication = request.getAuthentication();
269+
ServerWebExchange exchange = request.getExchange();
270+
271+
return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange).
272+
flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
273+
.thenReturn(result));
274+
}
275+
255276
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
256277
OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
257278
ServerWebExchange exchange = r.getExchange();
@@ -280,6 +301,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
280301
if (refreshToken == null) {
281302
return false;
282303
}
304+
return hasTokenExpired(authorizedClient);
305+
}
306+
307+
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
283308
Instant now = this.clock.instant();
284309
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
285310
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,16 @@ private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId,
332332
if (clientRegistration == null) {
333333
throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
334334
}
335-
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
335+
if (isClientCredentialsGrantType(clientRegistration)) {
336336
return getAuthorizedClient(clientRegistration, attrs);
337337
}
338338
throw new ClientAuthorizationRequiredException(clientRegistrationId);
339339
}
340340

341+
private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
342+
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
343+
}
344+
341345

342346
private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration,
343347
Map<String, Object> attrs) {
@@ -366,7 +370,11 @@ private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegi
366370
}
367371

368372
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
369-
if (shouldRefresh(authorizedClient)) {
373+
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
374+
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
375+
//Client credentials grant do not have refresh tokens but can expire so we need to get another one
376+
return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, request.attributes()));
377+
} else if (shouldRefresh(authorizedClient)) {
370378
return refreshAuthorizedClient(request, next, authorizedClient);
371379
}
372380
return Mono.just(authorizedClient);
@@ -407,6 +415,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
407415
if (refreshToken == null) {
408416
return false;
409417
}
418+
return hasTokenExpired(authorizedClient);
419+
}
420+
421+
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
410422
Instant now = this.clock.instant();
411423
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
412424
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

+87
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.security.oauth2.client.registration.ClientRegistration;
4343
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
4444
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
45+
import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request;
4546
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
4647
import org.springframework.security.oauth2.core.OAuth2AccessToken;
4748
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@@ -67,6 +68,7 @@
6768
import static org.assertj.core.api.Assertions.assertThat;
6869
import static org.mockito.ArgumentMatchers.any;
6970
import static org.mockito.ArgumentMatchers.eq;
71+
import static org.mockito.Mockito.never;
7072
import static org.mockito.Mockito.verify;
7173
import static org.mockito.Mockito.verifyZeroInteractions;
7274
import static org.mockito.Mockito.when;
@@ -86,6 +88,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
8688
@Mock
8789
private ReactiveClientRegistrationRepository clientRegistrationRepository;
8890

91+
@Mock
92+
private OAuth2AuthorizedClientResolver oAuth2AuthorizedClientResolver;
93+
8994
@Mock
9095
private ServerWebExchange serverWebExchange;
9196

@@ -144,6 +149,88 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
144149
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
145150
}
146151

152+
@Test
153+
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
154+
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
155+
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
156+
String clientRegistrationId = registration.getClientId();
157+
158+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);
159+
160+
OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
161+
"new-token",
162+
Instant.now(),
163+
Instant.now().plus(Duration.ofDays(1)));
164+
OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration,
165+
"principalName", newAccessToken, null);
166+
Request r = new Request(clientRegistrationId, authentication, null);
167+
when(this.oAuth2AuthorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient));
168+
when(this.oAuth2AuthorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r));
169+
170+
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
171+
172+
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
173+
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
174+
175+
OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
176+
this.accessToken.getTokenValue(),
177+
issuedAt,
178+
accessTokenExpiresAt);
179+
180+
181+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
182+
"principalName", accessToken, null);
183+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
184+
.attributes(oauth2AuthorizedClient(authorizedClient))
185+
.build();
186+
187+
188+
this.function.filter(request, this.exchange)
189+
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
190+
.block();
191+
192+
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
193+
verify(this.oAuth2AuthorizedClientResolver).clientCredentials(any(), any(), any());
194+
verify(this.oAuth2AuthorizedClientResolver).createDefaultedRequest(any(), any(), any());
195+
196+
List<ClientRequest> requests = this.exchange.getRequests();
197+
assertThat(requests).hasSize(1);
198+
ClientRequest request1 = requests.get(0);
199+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
200+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
201+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
202+
assertThat(getBody(request1)).isEmpty();
203+
}
204+
205+
@Test
206+
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
207+
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
208+
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
209+
210+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);
211+
212+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
213+
"principalName", this.accessToken, null);
214+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
215+
.attributes(oauth2AuthorizedClient(authorizedClient))
216+
.build();
217+
218+
this.function.filter(request, this.exchange)
219+
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
220+
.block();
221+
222+
verify(this.oAuth2AuthorizedClientResolver, never()).clientCredentials(any(), any(), any());
223+
verify(this.oAuth2AuthorizedClientResolver, never()).createDefaultedRequest(any(), any(), any());
224+
225+
List<ClientRequest> requests = this.exchange.getRequests();
226+
assertThat(requests).hasSize(1);
227+
ClientRequest request1 = requests.get(0);
228+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
229+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
230+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
231+
assertThat(getBody(request1)).isEmpty();
232+
}
233+
147234
@Test
148235
public void filterWhenRefreshRequiredThenRefresh() {
149236
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

+75
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import static org.mockito.ArgumentMatchers.any;
7979
import static org.mockito.ArgumentMatchers.eq;
8080
import static org.mockito.Mockito.mock;
81+
import static org.mockito.Mockito.never;
8182
import static org.mockito.Mockito.verify;
8283
import static org.mockito.Mockito.verifyZeroInteractions;
8384
import static org.mockito.Mockito.when;
@@ -423,6 +424,80 @@ public void filterWhenRefreshRequiredThenRefresh() {
423424
assertThat(getBody(request1)).isEmpty();
424425
}
425426

427+
@Test
428+
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
429+
this.registration = TestClientRegistrations.clientCredentials().build();
430+
431+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
432+
this.authorizedClientRepository);
433+
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
434+
435+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
436+
"principalName", this.accessToken, null);
437+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
438+
.attributes(oauth2AuthorizedClient(authorizedClient))
439+
.attributes(authentication(this.authentication))
440+
.build();
441+
442+
this.function.filter(request, this.exchange).block();
443+
444+
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
445+
446+
verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any());
447+
448+
List<ClientRequest> requests = this.exchange.getRequests();
449+
assertThat(requests).hasSize(1);
450+
451+
ClientRequest request1 = requests.get(0);
452+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
453+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
454+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
455+
assertThat(getBody(request1)).isEmpty();
456+
}
457+
458+
@Test
459+
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
460+
this.registration = TestClientRegistrations.clientCredentials().build();
461+
462+
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
463+
.accessTokenResponse().build();
464+
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
465+
accessTokenResponse);
466+
467+
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
468+
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
469+
470+
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
471+
this.accessToken.getTokenValue(),
472+
issuedAt,
473+
accessTokenExpiresAt);
474+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
475+
this.authorizedClientRepository);
476+
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
477+
478+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
479+
"principalName", this.accessToken, null);
480+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
481+
.attributes(oauth2AuthorizedClient(authorizedClient))
482+
.attributes(authentication(this.authentication))
483+
.build();
484+
485+
this.function.filter(request, this.exchange).block();
486+
487+
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
488+
489+
verify(clientCredentialsTokenResponseClient).getTokenResponse(any());
490+
491+
List<ClientRequest> requests = this.exchange.getRequests();
492+
assertThat(requests).hasSize(1);
493+
494+
ClientRequest request1 = requests.get(0);
495+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token");
496+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
497+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
498+
assertThat(getBody(request1)).isEmpty();
499+
}
500+
426501
@Test
427502
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
428503
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")

0 commit comments

Comments
 (0)