Skip to content

Commit dcdeab5

Browse files
committed
DefaultReactiveOAuth2AuthorizedClientManager defaults ServerWebExchange
Fixes gh-7390
1 parent 96d44cd commit dcdeab5

File tree

2 files changed

+101
-61
lines changed

2 files changed

+101
-61
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java

+42-20
Original file line numberDiff line numberDiff line change
@@ -70,35 +70,52 @@ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRe
7070

7171
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
7272
Authentication principal = authorizeRequest.getPrincipal();
73-
7473
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
75-
Assert.notNull(serverWebExchange, "serverWebExchange cannot be null");
7674

7775
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
78-
.switchIfEmpty(Mono.defer(() ->
79-
this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
76+
.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
8077
.flatMap(authorizedClient -> {
8178
// Re-authorize
8279
return authorizationContext(authorizeRequest, authorizedClient)
8380
.flatMap(this.authorizedClientProvider::authorize)
84-
.doOnNext(reauthorizedClient ->
85-
this.authorizedClientRepository.saveAuthorizedClient(
86-
reauthorizedClient, principal, serverWebExchange))
81+
.flatMap(reauthorizedClient -> saveAuthorizedClient(reauthorizedClient, principal, serverWebExchange))
8782
// Default to the existing authorizedClient if the client was not re-authorized
8883
.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
8984
authorizeRequest.getAuthorizedClient() : authorizedClient);
9085
})
91-
.switchIfEmpty(Mono.defer(() ->
86+
.switchIfEmpty(Mono.deferWithContext(context ->
9287
// Authorize
9388
this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
9489
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
9590
"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
9691
.flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration))
9792
.flatMap(this.authorizedClientProvider::authorize)
98-
.doOnNext(authorizedClient ->
99-
this.authorizedClientRepository.saveAuthorizedClient(
100-
authorizedClient, principal, serverWebExchange))
101-
));
93+
.flatMap(authorizedClient -> saveAuthorizedClient(authorizedClient, principal, serverWebExchange))
94+
.subscriberContext(context)
95+
)
96+
);
97+
}
98+
99+
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) {
100+
return Mono.justOrEmpty(serverWebExchange)
101+
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
102+
.flatMap(exchange -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange));
103+
}
104+
105+
private Mono<OAuth2AuthorizedClient> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, ServerWebExchange serverWebExchange) {
106+
return Mono.justOrEmpty(serverWebExchange)
107+
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
108+
.map(exchange -> {
109+
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange);
110+
return authorizedClient;
111+
})
112+
.defaultIfEmpty(authorizedClient);
113+
}
114+
115+
private static Mono<ServerWebExchange> currentServerWebExchange() {
116+
return Mono.subscriberContext()
117+
.filter(c -> c.hasKey(ServerWebExchange.class))
118+
.map(c -> c.get(ServerWebExchange.class));
102119
}
103120

104121
private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
@@ -158,15 +175,20 @@ public static class DefaultContextAttributesMapper implements Function<OAuth2Aut
158175

159176
@Override
160177
public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
161-
Map<String, Object> contextAttributes = Collections.emptyMap();
162178
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
163-
String scope = serverWebExchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
164-
if (StringUtils.hasText(scope)) {
165-
contextAttributes = new HashMap<>();
166-
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
167-
StringUtils.delimitedListToStringArray(scope, " "));
168-
}
169-
return Mono.just(contextAttributes);
179+
return Mono.justOrEmpty(serverWebExchange)
180+
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
181+
.flatMap(exchange -> {
182+
Map<String, Object> contextAttributes = Collections.emptyMap();
183+
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
184+
if (StringUtils.hasText(scope)) {
185+
contextAttributes = new HashMap<>();
186+
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
187+
StringUtils.delimitedListToStringArray(scope, " "));
188+
}
189+
return Mono.just(contextAttributes);
190+
})
191+
.defaultIfEmpty(Collections.emptyMap());
170192
}
171193
}
172194
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java

+59-41
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
3535
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
3636
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
37-
import org.springframework.util.StringUtils;
3837
import org.springframework.web.server.ServerWebExchange;
3938
import reactor.core.publisher.Mono;
39+
import reactor.util.context.Context;
4040

4141
import java.util.Collections;
4242
import java.util.HashMap;
@@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
6464
private Authentication principal;
6565
private OAuth2AuthorizedClient authorizedClient;
6666
private MockServerWebExchange serverWebExchange;
67+
private Context context;
6768
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
6869

6970
@SuppressWarnings("unchecked")
@@ -75,6 +76,8 @@ public void setup() {
7576
this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class);
7677
when(this.authorizedClientRepository.loadAuthorizedClient(
7778
anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
79+
when(this.authorizedClientRepository.saveAuthorizedClient(
80+
any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
7881
this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
7982
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty());
8083
this.contextAttributesMapper = mock(Function.class);
@@ -88,6 +91,7 @@ public void setup() {
8891
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
8992
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
9093
this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
94+
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
9195
this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
9296
}
9397

@@ -119,16 +123,6 @@ public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException(
119123
.hasMessage("contextAttributesMapper cannot be null");
120124
}
121125

122-
@Test
123-
public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() {
124-
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
125-
.principal(this.principal)
126-
.build();
127-
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
128-
.isInstanceOf(IllegalArgumentException.class)
129-
.hasMessage("serverWebExchange cannot be null");
130-
}
131-
132126
@Test
133127
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
134128
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block())
@@ -140,9 +134,8 @@ public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
140134
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
141135
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id")
142136
.principal(this.principal)
143-
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
144137
.build();
145-
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
138+
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block())
146139
.isInstanceOf(IllegalArgumentException.class)
147140
.hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
148141
}
@@ -155,9 +148,9 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized()
155148

156149
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
157150
.principal(this.principal)
158-
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
159151
.build();
160-
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
152+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
153+
.subscriberContext(this.context).block();
161154

162155
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
163156
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@@ -168,24 +161,22 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized()
168161
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
169162

170163
assertThat(authorizedClient).isNull();
171-
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
172-
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
164+
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
173165
}
174166

175167
@SuppressWarnings("unchecked")
176168
@Test
177169
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
178170
when(this.clientRegistrationRepository.findByRegistrationId(
179171
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
180-
181172
when(this.authorizedClientProvider.authorize(
182173
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
183174

184175
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
185176
.principal(this.principal)
186-
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
187177
.build();
188-
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
178+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
179+
.subscriberContext(this.context).block();
189180

190181
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
191182
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@@ -200,6 +191,31 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
200191
eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange));
201192
}
202193

194+
@Test
195+
public void authorizeWhenNotAuthorizedAndSupportedProviderAndExchangeUnavailableThenAuthorizedButNotSaved() {
196+
when(this.clientRegistrationRepository.findByRegistrationId(
197+
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
198+
199+
when(this.authorizedClientProvider.authorize(
200+
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
201+
202+
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
203+
.principal(this.principal)
204+
.build();
205+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
206+
207+
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
208+
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
209+
210+
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
211+
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
212+
assertThat(authorizationContext.getAuthorizedClient()).isNull();
213+
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
214+
215+
assertThat(authorizedClient).isSameAs(this.authorizedClient);
216+
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
217+
}
218+
203219
@SuppressWarnings("unchecked")
204220
@Test
205221
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
@@ -216,9 +232,9 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
216232

217233
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
218234
.principal(this.principal)
219-
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
220235
.build();
221-
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
236+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
237+
.subscriberContext(this.context).block();
222238

223239
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
224240
verify(this.contextAttributesMapper).apply(any());
@@ -241,34 +257,31 @@ public void authorizeWhenRequestFormParameterUsernamePasswordThenMappedToContext
241257
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
242258

243259
// Set custom contextAttributesMapper capable of mapping the form parameters
244-
this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> {
245-
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
246-
return Mono.just(serverWebExchange)
260+
this.authorizedClientManager.setContextAttributesMapper(authorizeRequest ->
261+
currentServerWebExchange()
247262
.flatMap(ServerWebExchange::getFormData)
248263
.map(formData -> {
249264
Map<String, Object> contextAttributes = new HashMap<>();
250265
String username = formData.getFirst(OAuth2ParameterNames.USERNAME);
266+
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
251267
String password = formData.getFirst(OAuth2ParameterNames.PASSWORD);
252-
if (StringUtils.hasText(username) && StringUtils.hasText(password)) {
253-
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
254-
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
255-
}
268+
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
256269
return contextAttributes;
257-
});
258-
});
270+
})
271+
);
259272

260273
this.serverWebExchange = MockServerWebExchange.builder(
261274
MockServerHttpRequest
262275
.post("/")
263276
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
264277
.body("username=username&password=password"))
265278
.build();
279+
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
266280

267281
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
268282
.principal(this.principal)
269-
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
270283
.build();
271-
this.authorizedClientManager.authorize(authorizeRequest).block();
284+
this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block();
272285

273286
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
274287

@@ -284,9 +297,9 @@ public void authorizeWhenRequestFormParameterUsernamePasswordThenMappedToContext
284297
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
285298
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
286299
.principal(this.principal)
287-
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
288300
.build();
289-
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
301+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
302+
.subscriberContext(this.context).block();
290303

291304
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
292305
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@@ -297,8 +310,7 @@ public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
297310
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
298311

299312
assertThat(authorizedClient).isSameAs(this.authorizedClient);
300-
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
301-
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
313+
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
302314
}
303315

304316
@SuppressWarnings("unchecked")
@@ -312,9 +324,9 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() {
312324

313325
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
314326
.principal(this.principal)
315-
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
316327
.build();
317-
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
328+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
329+
.subscriberContext(this.context).block();
318330

319331
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
320332
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@@ -346,17 +358,23 @@ public void reauthorizeWhenRequestParameterScopeThenMappedToContext() {
346358
.get("/")
347359
.queryParam(OAuth2ParameterNames.SCOPE, "read write"))
348360
.build();
361+
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
349362

350363
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
351364
.principal(this.principal)
352-
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
353365
.build();
354-
this.authorizedClientManager.authorize(reauthorizeRequest).block();
366+
this.authorizedClientManager.authorize(reauthorizeRequest).subscriberContext(this.context).block();
355367

356368
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
357369

358370
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
359371
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
360372
assertThat(requestScopeAttribute).contains("read", "write");
361373
}
374+
375+
private Mono<ServerWebExchange> currentServerWebExchange() {
376+
return Mono.subscriberContext()
377+
.filter(c -> c.hasKey(ServerWebExchange.class))
378+
.map(c -> c.get(ServerWebExchange.class));
379+
}
362380
}

0 commit comments

Comments
 (0)