34
34
import org .springframework .security .oauth2 .core .TestOAuth2AccessTokens ;
35
35
import org .springframework .security .oauth2 .core .TestOAuth2RefreshTokens ;
36
36
import org .springframework .security .oauth2 .core .endpoint .OAuth2ParameterNames ;
37
- import org .springframework .util .StringUtils ;
38
37
import org .springframework .web .server .ServerWebExchange ;
39
38
import reactor .core .publisher .Mono ;
39
+ import reactor .util .context .Context ;
40
40
41
41
import java .util .Collections ;
42
42
import java .util .HashMap ;
@@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
64
64
private Authentication principal ;
65
65
private OAuth2AuthorizedClient authorizedClient ;
66
66
private MockServerWebExchange serverWebExchange ;
67
+ private Context context ;
67
68
private ArgumentCaptor <OAuth2AuthorizationContext > authorizationContextCaptor ;
68
69
69
70
@ SuppressWarnings ("unchecked" )
@@ -75,6 +76,8 @@ public void setup() {
75
76
this .authorizedClientRepository = mock (ServerOAuth2AuthorizedClientRepository .class );
76
77
when (this .authorizedClientRepository .loadAuthorizedClient (
77
78
anyString (), any (Authentication .class ), any (ServerWebExchange .class ))).thenReturn (Mono .empty ());
79
+ when (this .authorizedClientRepository .saveAuthorizedClient (
80
+ any (OAuth2AuthorizedClient .class ), any (Authentication .class ), any (ServerWebExchange .class ))).thenReturn (Mono .empty ());
78
81
this .authorizedClientProvider = mock (ReactiveOAuth2AuthorizedClientProvider .class );
79
82
when (this .authorizedClientProvider .authorize (any (OAuth2AuthorizationContext .class ))).thenReturn (Mono .empty ());
80
83
this .contextAttributesMapper = mock (Function .class );
@@ -88,6 +91,7 @@ public void setup() {
88
91
this .authorizedClient = new OAuth2AuthorizedClient (this .clientRegistration , this .principal .getName (),
89
92
TestOAuth2AccessTokens .scopes ("read" , "write" ), TestOAuth2RefreshTokens .refreshToken ());
90
93
this .serverWebExchange = MockServerWebExchange .builder (MockServerHttpRequest .get ("/" )).build ();
94
+ this .context = Context .of (ServerWebExchange .class , this .serverWebExchange );
91
95
this .authorizationContextCaptor = ArgumentCaptor .forClass (OAuth2AuthorizationContext .class );
92
96
}
93
97
@@ -119,16 +123,6 @@ public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException(
119
123
.hasMessage ("contextAttributesMapper cannot be null" );
120
124
}
121
125
122
- @ Test
123
- public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException () {
124
- OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest .withClientRegistrationId (this .clientRegistration .getRegistrationId ())
125
- .principal (this .principal )
126
- .build ();
127
- assertThatThrownBy (() -> this .authorizedClientManager .authorize (authorizeRequest ).block ())
128
- .isInstanceOf (IllegalArgumentException .class )
129
- .hasMessage ("serverWebExchange cannot be null" );
130
- }
131
-
132
126
@ Test
133
127
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException () {
134
128
assertThatThrownBy (() -> this .authorizedClientManager .authorize (null ).block ())
@@ -140,9 +134,8 @@ public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
140
134
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException () {
141
135
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest .withClientRegistrationId ("invalid-registration-id" )
142
136
.principal (this .principal )
143
- .attribute (ServerWebExchange .class .getName (), this .serverWebExchange )
144
137
.build ();
145
- assertThatThrownBy (() -> this .authorizedClientManager .authorize (authorizeRequest ).block ())
138
+ assertThatThrownBy (() -> this .authorizedClientManager .authorize (authorizeRequest ).subscriberContext ( this . context ). block ())
146
139
.isInstanceOf (IllegalArgumentException .class )
147
140
.hasMessage ("Could not find ClientRegistration with id 'invalid-registration-id'" );
148
141
}
@@ -155,9 +148,9 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized()
155
148
156
149
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest .withClientRegistrationId (this .clientRegistration .getRegistrationId ())
157
150
.principal (this .principal )
158
- .attribute (ServerWebExchange .class .getName (), this .serverWebExchange )
159
151
.build ();
160
- OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (authorizeRequest ).block ();
152
+ OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (authorizeRequest )
153
+ .subscriberContext (this .context ).block ();
161
154
162
155
verify (this .authorizedClientProvider ).authorize (this .authorizationContextCaptor .capture ());
163
156
verify (this .contextAttributesMapper ).apply (eq (authorizeRequest ));
@@ -168,24 +161,22 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized()
168
161
assertThat (authorizationContext .getPrincipal ()).isEqualTo (this .principal );
169
162
170
163
assertThat (authorizedClient ).isNull ();
171
- verify (this .authorizedClientRepository , never ()).saveAuthorizedClient (
172
- any (OAuth2AuthorizedClient .class ), eq (this .principal ), eq (this .serverWebExchange ));
164
+ verify (this .authorizedClientRepository , never ()).saveAuthorizedClient (any (), any (), any ());
173
165
}
174
166
175
167
@ SuppressWarnings ("unchecked" )
176
168
@ Test
177
169
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized () {
178
170
when (this .clientRegistrationRepository .findByRegistrationId (
179
171
eq (this .clientRegistration .getRegistrationId ()))).thenReturn (Mono .just (this .clientRegistration ));
180
-
181
172
when (this .authorizedClientProvider .authorize (
182
173
any (OAuth2AuthorizationContext .class ))).thenReturn (Mono .just (this .authorizedClient ));
183
174
184
175
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest .withClientRegistrationId (this .clientRegistration .getRegistrationId ())
185
176
.principal (this .principal )
186
- .attribute (ServerWebExchange .class .getName (), this .serverWebExchange )
187
177
.build ();
188
- OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (authorizeRequest ).block ();
178
+ OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (authorizeRequest )
179
+ .subscriberContext (this .context ).block ();
189
180
190
181
verify (this .authorizedClientProvider ).authorize (this .authorizationContextCaptor .capture ());
191
182
verify (this .contextAttributesMapper ).apply (eq (authorizeRequest ));
@@ -200,6 +191,31 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
200
191
eq (this .authorizedClient ), eq (this .principal ), eq (this .serverWebExchange ));
201
192
}
202
193
194
+ @ Test
195
+ public void authorizeWhenNotAuthorizedAndSupportedProviderAndExchangeUnavailableThenAuthorizedButNotSaved () {
196
+ when (this .clientRegistrationRepository .findByRegistrationId (
197
+ eq (this .clientRegistration .getRegistrationId ()))).thenReturn (Mono .just (this .clientRegistration ));
198
+
199
+ when (this .authorizedClientProvider .authorize (
200
+ any (OAuth2AuthorizationContext .class ))).thenReturn (Mono .just (this .authorizedClient ));
201
+
202
+ OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest .withClientRegistrationId (this .clientRegistration .getRegistrationId ())
203
+ .principal (this .principal )
204
+ .build ();
205
+ OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (authorizeRequest ).block ();
206
+
207
+ verify (this .authorizedClientProvider ).authorize (this .authorizationContextCaptor .capture ());
208
+ verify (this .contextAttributesMapper ).apply (eq (authorizeRequest ));
209
+
210
+ OAuth2AuthorizationContext authorizationContext = this .authorizationContextCaptor .getValue ();
211
+ assertThat (authorizationContext .getClientRegistration ()).isEqualTo (this .clientRegistration );
212
+ assertThat (authorizationContext .getAuthorizedClient ()).isNull ();
213
+ assertThat (authorizationContext .getPrincipal ()).isEqualTo (this .principal );
214
+
215
+ assertThat (authorizedClient ).isSameAs (this .authorizedClient );
216
+ verify (this .authorizedClientRepository , never ()).saveAuthorizedClient (any (), any (), any ());
217
+ }
218
+
203
219
@ SuppressWarnings ("unchecked" )
204
220
@ Test
205
221
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized () {
@@ -216,9 +232,9 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
216
232
217
233
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest .withClientRegistrationId (this .clientRegistration .getRegistrationId ())
218
234
.principal (this .principal )
219
- .attribute (ServerWebExchange .class .getName (), this .serverWebExchange )
220
235
.build ();
221
- OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (authorizeRequest ).block ();
236
+ OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (authorizeRequest )
237
+ .subscriberContext (this .context ).block ();
222
238
223
239
verify (this .authorizedClientProvider ).authorize (this .authorizationContextCaptor .capture ());
224
240
verify (this .contextAttributesMapper ).apply (any ());
@@ -241,34 +257,31 @@ public void authorizeWhenRequestFormParameterUsernamePasswordThenMappedToContext
241
257
when (this .authorizedClientProvider .authorize (any (OAuth2AuthorizationContext .class ))).thenReturn (Mono .just (this .authorizedClient ));
242
258
243
259
// Set custom contextAttributesMapper capable of mapping the form parameters
244
- this .authorizedClientManager .setContextAttributesMapper (authorizeRequest -> {
245
- ServerWebExchange serverWebExchange = authorizeRequest .getAttribute (ServerWebExchange .class .getName ());
246
- return Mono .just (serverWebExchange )
260
+ this .authorizedClientManager .setContextAttributesMapper (authorizeRequest ->
261
+ currentServerWebExchange ()
247
262
.flatMap (ServerWebExchange ::getFormData )
248
263
.map (formData -> {
249
264
Map <String , Object > contextAttributes = new HashMap <>();
250
265
String username = formData .getFirst (OAuth2ParameterNames .USERNAME );
266
+ contextAttributes .put (OAuth2AuthorizationContext .USERNAME_ATTRIBUTE_NAME , username );
251
267
String password = formData .getFirst (OAuth2ParameterNames .PASSWORD );
252
- if (StringUtils .hasText (username ) && StringUtils .hasText (password )) {
253
- contextAttributes .put (OAuth2AuthorizationContext .USERNAME_ATTRIBUTE_NAME , username );
254
- contextAttributes .put (OAuth2AuthorizationContext .PASSWORD_ATTRIBUTE_NAME , password );
255
- }
268
+ contextAttributes .put (OAuth2AuthorizationContext .PASSWORD_ATTRIBUTE_NAME , password );
256
269
return contextAttributes ;
257
- });
258
- } );
270
+ })
271
+ );
259
272
260
273
this .serverWebExchange = MockServerWebExchange .builder (
261
274
MockServerHttpRequest
262
275
.post ("/" )
263
276
.contentType (MediaType .APPLICATION_FORM_URLENCODED )
264
277
.body ("username=username&password=password" ))
265
278
.build ();
279
+ this .context = Context .of (ServerWebExchange .class , this .serverWebExchange );
266
280
267
281
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest .withClientRegistrationId (this .clientRegistration .getRegistrationId ())
268
282
.principal (this .principal )
269
- .attribute (ServerWebExchange .class .getName (), this .serverWebExchange )
270
283
.build ();
271
- this .authorizedClientManager .authorize (authorizeRequest ).block ();
284
+ this .authorizedClientManager .authorize (authorizeRequest ).subscriberContext ( this . context ). block ();
272
285
273
286
verify (this .authorizedClientProvider ).authorize (this .authorizationContextCaptor .capture ());
274
287
@@ -284,9 +297,9 @@ public void authorizeWhenRequestFormParameterUsernamePasswordThenMappedToContext
284
297
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized () {
285
298
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest .withAuthorizedClient (this .authorizedClient )
286
299
.principal (this .principal )
287
- .attribute (ServerWebExchange .class .getName (), this .serverWebExchange )
288
300
.build ();
289
- OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (reauthorizeRequest ).block ();
301
+ OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (reauthorizeRequest )
302
+ .subscriberContext (this .context ).block ();
290
303
291
304
verify (this .authorizedClientProvider ).authorize (this .authorizationContextCaptor .capture ());
292
305
verify (this .contextAttributesMapper ).apply (eq (reauthorizeRequest ));
@@ -297,8 +310,7 @@ public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
297
310
assertThat (authorizationContext .getPrincipal ()).isEqualTo (this .principal );
298
311
299
312
assertThat (authorizedClient ).isSameAs (this .authorizedClient );
300
- verify (this .authorizedClientRepository , never ()).saveAuthorizedClient (
301
- any (OAuth2AuthorizedClient .class ), eq (this .principal ), eq (this .serverWebExchange ));
313
+ verify (this .authorizedClientRepository , never ()).saveAuthorizedClient (any (), any (), any ());
302
314
}
303
315
304
316
@ SuppressWarnings ("unchecked" )
@@ -312,9 +324,9 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() {
312
324
313
325
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest .withAuthorizedClient (this .authorizedClient )
314
326
.principal (this .principal )
315
- .attribute (ServerWebExchange .class .getName (), this .serverWebExchange )
316
327
.build ();
317
- OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (reauthorizeRequest ).block ();
328
+ OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (reauthorizeRequest )
329
+ .subscriberContext (this .context ).block ();
318
330
319
331
verify (this .authorizedClientProvider ).authorize (this .authorizationContextCaptor .capture ());
320
332
verify (this .contextAttributesMapper ).apply (eq (reauthorizeRequest ));
@@ -346,17 +358,23 @@ public void reauthorizeWhenRequestParameterScopeThenMappedToContext() {
346
358
.get ("/" )
347
359
.queryParam (OAuth2ParameterNames .SCOPE , "read write" ))
348
360
.build ();
361
+ this .context = Context .of (ServerWebExchange .class , this .serverWebExchange );
349
362
350
363
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest .withAuthorizedClient (this .authorizedClient )
351
364
.principal (this .principal )
352
- .attribute (ServerWebExchange .class .getName (), this .serverWebExchange )
353
365
.build ();
354
- this .authorizedClientManager .authorize (reauthorizeRequest ).block ();
366
+ this .authorizedClientManager .authorize (reauthorizeRequest ).subscriberContext ( this . context ). block ();
355
367
356
368
verify (this .authorizedClientProvider ).authorize (this .authorizationContextCaptor .capture ());
357
369
358
370
OAuth2AuthorizationContext authorizationContext = this .authorizationContextCaptor .getValue ();
359
371
String [] requestScopeAttribute = authorizationContext .getAttribute (OAuth2AuthorizationContext .REQUEST_SCOPE_ATTRIBUTE_NAME );
360
372
assertThat (requestScopeAttribute ).contains ("read" , "write" );
361
373
}
374
+
375
+ private Mono <ServerWebExchange > currentServerWebExchange () {
376
+ return Mono .subscriberContext ()
377
+ .filter (c -> c .hasKey (ServerWebExchange .class ))
378
+ .map (c -> c .get (ServerWebExchange .class ));
379
+ }
362
380
}
0 commit comments