Skip to content

Commit 70df9b5

Browse files
committed
Fix blocking in ServletOAuth2AuthorizedClientExchangeFilterFunction
Fixes gh-6589
1 parent 7e84540 commit 70df9b5

File tree

5 files changed

+319
-79
lines changed

5 files changed

+319
-79
lines changed

gradle/dependency-management.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ dependencyManagement {
7070
dependency 'commons-lang:commons-lang:2.6'
7171
dependency 'commons-logging:commons-logging:1.2'
7272
dependency 'dom4j:dom4j:1.6.1'
73+
dependency 'io.projectreactor.tools:blockhound:1.0.0.M4'
7374
dependency 'javax.activation:activation:1.1.1'
7475
dependency 'javax.annotation:jsr250-api:1.0'
7576
dependency 'javax.inject:javax.inject:1'

oauth2/oauth2-client/spring-security-oauth2-client.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies {
1717
testCompile 'com.fasterxml.jackson.core:jackson-databind'
1818
testCompile 'io.projectreactor.netty:reactor-netty'
1919
testCompile 'io.projectreactor:reactor-test'
20+
testCompile 'io.projectreactor.tools:blockhound'
2021

2122
provided 'javax.servlet:javax.servlet-api'
2223
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

+36-19
Original file line numberDiff line numberDiff line change
@@ -288,20 +288,33 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
288288

289289
@Override
290290
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
291-
return Mono.just(request)
292-
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
293-
.switchIfEmpty(mergeRequestAttributesFromContext(request))
291+
return mergeRequestAttributesIfNecessary(request)
294292
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
295293
.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
294+
.switchIfEmpty(Mono.defer(() ->
295+
mergeRequestAttributesIfNecessary(request)
296+
.filter(req -> req.attribute(CLIENT_REGISTRATION_ID_ATTR_NAME).isPresent())
297+
.flatMap(this::authorizeClient)
298+
))
296299
.map(authorizedClient -> bearer(request, authorizedClient))
297300
.flatMap(next::exchange)
298-
.switchIfEmpty(next.exchange(request));
301+
.switchIfEmpty(Mono.defer(() -> next.exchange(request)));
302+
}
303+
304+
private Mono<ClientRequest> mergeRequestAttributesIfNecessary(ClientRequest request) {
305+
if (!request.attribute(HTTP_SERVLET_REQUEST_ATTR_NAME).isPresent() ||
306+
!request.attribute(HTTP_SERVLET_RESPONSE_ATTR_NAME).isPresent() ||
307+
!request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) {
308+
return mergeRequestAttributesFromContext(request);
309+
} else {
310+
return Mono.just(request);
311+
}
299312
}
300313

301314
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
302-
return Mono.just(ClientRequest.from(request))
303-
.flatMap(builder -> Mono.subscriberContext()
304-
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
315+
ClientRequest.Builder builder = ClientRequest.from(request);
316+
return Mono.subscriberContext()
317+
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx)))
305318
.map(ClientRequest.Builder::build);
306319
}
307320

@@ -348,35 +361,37 @@ private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
348361
return;
349362
}
350363

351-
Authentication authentication = getAuthentication(attrs);
352364
String clientRegistrationId = getClientRegistrationId(attrs);
353365
if (clientRegistrationId == null) {
354366
clientRegistrationId = this.defaultClientRegistrationId;
355367
}
368+
Authentication authentication = getAuthentication(attrs);
356369
if (clientRegistrationId == null
357370
&& this.defaultOAuth2AuthorizedClient
358371
&& authentication instanceof OAuth2AuthenticationToken) {
359372
clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
360373
}
361-
if (clientRegistrationId != null) {
362-
HttpServletRequest request = getRequest(attrs);
363-
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository
364-
.loadAuthorizedClient(clientRegistrationId, authentication,
365-
request);
366-
if (authorizedClient == null) {
367-
authorizedClient = getAuthorizedClient(clientRegistrationId, attrs);
374+
HttpServletRequest request = getRequest(attrs);
375+
if (clientRegistrationId != null && authentication != null && request != null) {
376+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
377+
clientRegistrationId, authentication, request);
378+
if (authorizedClient != null) {
379+
oauth2AuthorizedClient(authorizedClient).accept(attrs);
368380
}
369-
oauth2AuthorizedClient(authorizedClient).accept(attrs);
370381
}
371382
}
372383

373-
private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map<String, Object> attrs) {
384+
private Mono<OAuth2AuthorizedClient> authorizeClient(ClientRequest request) {
385+
Map<String, Object> attrs = request.attributes();
386+
String clientRegistrationId = getClientRegistrationId(attrs);
374387
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
375388
if (clientRegistration == null) {
376389
throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
377390
}
378391
if (isClientCredentialsGrantType(clientRegistration)) {
379-
return authorizeWithClientCredentials(clientRegistration, attrs);
392+
// NOTE: 'authorizeWithClientCredentials()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic())
393+
// since it performs a blocking I/O operation using RestTemplate internally
394+
return Mono.fromSupplier(() -> authorizeWithClientCredentials(clientRegistration, attrs)).subscribeOn(Schedulers.elastic());
380395
}
381396
throw new ClientAuthorizationRequiredException(clientRegistrationId);
382397
}
@@ -414,7 +429,9 @@ private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, Exc
414429
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
415430
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
416431
// Client credentials grant do not have refresh tokens but can expire so we need to get another one
417-
return Mono.fromSupplier(() -> authorizeWithClientCredentials(clientRegistration, request.attributes()));
432+
// NOTE: 'authorizeWithClientCredentials()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic())
433+
// since it performs a blocking I/O operation using RestTemplate internally
434+
return Mono.fromSupplier(() -> authorizeWithClientCredentials(clientRegistration, request.attributes())).subscribeOn(Schedulers.elastic());
418435
} else if (shouldRefreshToken(authorizedClient)) {
419436
return authorizeWithRefreshToken(request, next, authorizedClient);
420437
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.client.web.reactive.function.client;
17+
18+
import okhttp3.mockwebserver.MockResponse;
19+
import okhttp3.mockwebserver.MockWebServer;
20+
import org.junit.After;
21+
import org.junit.Before;
22+
import org.junit.BeforeClass;
23+
import org.junit.Test;
24+
import org.mockito.ArgumentCaptor;
25+
import org.springframework.http.HttpHeaders;
26+
import org.springframework.http.MediaType;
27+
import org.springframework.mock.web.MockHttpServletRequest;
28+
import org.springframework.mock.web.MockHttpServletResponse;
29+
import org.springframework.security.authentication.TestingAuthenticationToken;
30+
import org.springframework.security.core.Authentication;
31+
import org.springframework.security.core.context.SecurityContextHolder;
32+
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
33+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
34+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
35+
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
36+
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
37+
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
38+
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
39+
import org.springframework.security.oauth2.core.OAuth2AccessToken;
40+
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
41+
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
42+
import org.springframework.web.context.request.RequestContextHolder;
43+
import org.springframework.web.context.request.ServletRequestAttributes;
44+
import org.springframework.web.reactive.function.client.WebClient;
45+
import reactor.blockhound.BlockHound;
46+
47+
import javax.servlet.http.HttpServletRequest;
48+
import javax.servlet.http.HttpServletResponse;
49+
import java.time.Duration;
50+
import java.time.Instant;
51+
import java.util.Arrays;
52+
import java.util.HashSet;
53+
54+
import static org.assertj.core.api.Assertions.assertThat;
55+
import static org.mockito.ArgumentMatchers.eq;
56+
import static org.mockito.Mockito.*;
57+
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
58+
59+
/**
60+
* @author Joe Grandja
61+
*/
62+
public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
63+
private ClientRegistrationRepository clientRegistrationRepository;
64+
private OAuth2AuthorizedClientRepository authorizedClientRepository;
65+
private ServletOAuth2AuthorizedClientExchangeFilterFunction authorizedClientFilter;
66+
private MockWebServer server;
67+
private String serverUrl;
68+
private WebClient webClient;
69+
private Authentication authentication;
70+
private MockHttpServletRequest request;
71+
private MockHttpServletResponse response;
72+
73+
@BeforeClass
74+
public static void setUpBlockingChecks() {
75+
// IMPORTANT:
76+
// Before enabling BlockHound, we need to force the initialization of
77+
// java.lang.Package.defineSystemPackage(). When the JVM loads java.lang.Package.getSystemPackage(),
78+
// it attempts to java.lang.Package.loadManifest() which is blocking I/O and triggers BlockHound to error.
79+
// The following code forces the loading of the manifest.
80+
// NOTE: This is an issue with JDK 8. It's been tested on JDK 10 and works fine w/o this workaround.
81+
Class.class.getPackage();
82+
83+
BlockHound.install();
84+
}
85+
86+
@Before
87+
public void setUp() throws Exception {
88+
this.clientRegistrationRepository = mock(ClientRegistrationRepository.class);
89+
final OAuth2AuthorizedClientRepository delegate = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(
90+
new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository));
91+
this.authorizedClientRepository = spy(new OAuth2AuthorizedClientRepository() {
92+
@Override
93+
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request) {
94+
return delegate.loadAuthorizedClient(clientRegistrationId, principal, request);
95+
}
96+
97+
@Override
98+
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, HttpServletRequest request, HttpServletResponse response) {
99+
delegate.saveAuthorizedClient(authorizedClient, principal, request, response);
100+
}
101+
102+
@Override
103+
public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request, HttpServletResponse response) {
104+
delegate.removeAuthorizedClient(clientRegistrationId, principal, request, response);
105+
}
106+
});
107+
this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
108+
this.clientRegistrationRepository, this.authorizedClientRepository);
109+
this.authorizedClientFilter.afterPropertiesSet();
110+
this.server = new MockWebServer();
111+
this.server.start();
112+
this.serverUrl = this.server.url("/").toString();
113+
this.webClient = WebClient.builder()
114+
.apply(this.authorizedClientFilter.oauth2Configuration())
115+
.build();
116+
this.authentication = new TestingAuthenticationToken("principal", "password");
117+
SecurityContextHolder.getContext().setAuthentication(this.authentication);
118+
this.request = new MockHttpServletRequest();
119+
this.response = new MockHttpServletResponse();
120+
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response));
121+
}
122+
123+
@After
124+
public void cleanup() throws Exception {
125+
this.authorizedClientFilter.destroy();
126+
this.server.shutdown();
127+
SecurityContextHolder.clearContext();
128+
RequestContextHolder.resetRequestAttributes();
129+
}
130+
131+
@Test
132+
public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() {
133+
String accessTokenResponse = "{\n" +
134+
" \"access_token\": \"access-token-1234\",\n" +
135+
" \"token_type\": \"bearer\",\n" +
136+
" \"expires_in\": \"3600\",\n" +
137+
" \"scope\": \"read write\"\n" +
138+
"}\n";
139+
String clientResponse = "{\n" +
140+
" \"attribute1\": \"value1\",\n" +
141+
" \"attribute2\": \"value2\"\n" +
142+
"}\n";
143+
144+
this.server.enqueue(jsonResponse(accessTokenResponse));
145+
this.server.enqueue(jsonResponse(clientResponse));
146+
147+
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build();
148+
when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration);
149+
150+
this.webClient
151+
.get()
152+
.uri(this.serverUrl)
153+
.attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
154+
.retrieve()
155+
.bodyToMono(String.class)
156+
.block();
157+
158+
assertThat(this.server.getRequestCount()).isEqualTo(2);
159+
160+
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
161+
verify(this.authorizedClientRepository).saveAuthorizedClient(
162+
authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response));
163+
assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration);
164+
}
165+
166+
@Test
167+
public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() {
168+
String accessTokenResponse = "{\n" +
169+
" \"access_token\": \"refreshed-access-token\",\n" +
170+
" \"token_type\": \"bearer\",\n" +
171+
" \"expires_in\": \"3600\"\n" +
172+
"}\n";
173+
String clientResponse = "{\n" +
174+
" \"attribute1\": \"value1\",\n" +
175+
" \"attribute2\": \"value2\"\n" +
176+
"}\n";
177+
178+
this.server.enqueue(jsonResponse(accessTokenResponse));
179+
this.server.enqueue(jsonResponse(clientResponse));
180+
181+
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl).build();
182+
when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration);
183+
184+
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
185+
Instant expiresAt = issuedAt.plus(Duration.ofHours(1));
186+
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
187+
"expired-access-token", issuedAt, expiresAt, new HashSet<>(Arrays.asList("read", "write")));
188+
OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
189+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
190+
clientRegistration, this.authentication.getName(), accessToken, refreshToken);
191+
doReturn(authorizedClient).when(this.authorizedClientRepository).loadAuthorizedClient(
192+
eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.request));
193+
194+
this.webClient
195+
.get()
196+
.uri(this.serverUrl)
197+
.attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
198+
.retrieve()
199+
.bodyToMono(String.class)
200+
.block();
201+
202+
assertThat(this.server.getRequestCount()).isEqualTo(2);
203+
204+
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
205+
verify(this.authorizedClientRepository).saveAuthorizedClient(
206+
authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response));
207+
OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue();
208+
assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration);
209+
assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token");
210+
}
211+
212+
@Test
213+
public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() {
214+
String accessTokenResponse = "{\n" +
215+
" \"access_token\": \"access-token-1234\",\n" +
216+
" \"token_type\": \"bearer\",\n" +
217+
" \"expires_in\": \"3600\",\n" +
218+
" \"scope\": \"read write\"\n" +
219+
"}\n";
220+
String clientResponse = "{\n" +
221+
" \"attribute1\": \"value1\",\n" +
222+
" \"attribute2\": \"value2\"\n" +
223+
"}\n";
224+
225+
// Client 1
226+
this.server.enqueue(jsonResponse(accessTokenResponse));
227+
this.server.enqueue(jsonResponse(clientResponse));
228+
229+
ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials()
230+
.registrationId("client-1").tokenUri(this.serverUrl).build();
231+
when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))).thenReturn(clientRegistration1);
232+
233+
// Client 2
234+
this.server.enqueue(jsonResponse(accessTokenResponse));
235+
this.server.enqueue(jsonResponse(clientResponse));
236+
237+
ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials()
238+
.registrationId("client-2").tokenUri(this.serverUrl).build();
239+
when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))).thenReturn(clientRegistration2);
240+
241+
this.webClient
242+
.get()
243+
.uri(this.serverUrl)
244+
.attributes(clientRegistrationId(clientRegistration1.getRegistrationId()))
245+
.retrieve()
246+
.bodyToMono(String.class)
247+
.flatMap(response -> this.webClient
248+
.get()
249+
.uri(this.serverUrl)
250+
.attributes(clientRegistrationId(clientRegistration2.getRegistrationId()))
251+
.retrieve()
252+
.bodyToMono(String.class))
253+
.block();
254+
255+
assertThat(this.server.getRequestCount()).isEqualTo(4);
256+
257+
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
258+
verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient(
259+
authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response));
260+
assertThat(authorizedClientCaptor.getAllValues().get(0).getClientRegistration()).isSameAs(clientRegistration1);
261+
assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2);
262+
}
263+
264+
private MockResponse jsonResponse(String json) {
265+
return new MockResponse()
266+
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
267+
.setBody(json);
268+
}
269+
}

0 commit comments

Comments
 (0)