Skip to content

Commit 9bae0a4

Browse files
rchigvintsevjgrandja
authored andcommitted
Allow to customize OAuth2AuthorizationRequestRedirectWebFilter in OAuth2LoginSpec
Fixes gh-7466
1 parent 2a5bd6e commit 9bae0a4

File tree

2 files changed

+108
-3
lines changed

2 files changed

+108
-3
lines changed

config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

+29
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
7777
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
7878
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
79+
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
7980
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
8081
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
8182
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
@@ -972,6 +973,8 @@ public class OAuth2LoginSpec {
972973

973974
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
974975

976+
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
977+
975978
private ReactiveAuthenticationManager authenticationManager;
976979

977980
private ServerSecurityContextRepository securityContextRepository;
@@ -1102,6 +1105,18 @@ public OAuth2LoginSpec authorizedClientRepository(ServerOAuth2AuthorizedClientRe
11021105
return this;
11031106
}
11041107

1108+
/**
1109+
* Sets authorization request repository for {@link OAuth2AuthorizationRequestRedirectWebFilter}.
1110+
*
1111+
* @param authorizationRequestRepository authorization request repository, must not be null
1112+
* @return the {@link OAuth2LoginSpec} for further configuration
1113+
*/
1114+
public OAuth2LoginSpec authorizationRequestRepository(ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository) {
1115+
Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null");
1116+
this.authorizationRequestRepository = authorizationRequestRepository;
1117+
return this;
1118+
}
1119+
11051120
/**
11061121
* Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
11071122
*
@@ -1146,6 +1161,12 @@ protected void configure(ServerHttpSecurity http) {
11461161
ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
11471162
ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
11481163
OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter();
1164+
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
1165+
getAuthorizationRequestRepository();
1166+
if (authorizationRequestRepository != null) {
1167+
oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository);
1168+
}
1169+
oauthRedirectFilter.setRequestCache(http.requestCache.requestCache);
11491170

11501171
ReactiveAuthenticationManager manager = getAuthenticationManager();
11511172

@@ -1246,6 +1267,14 @@ private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() {
12461267
return result;
12471268
}
12481269

1270+
@SuppressWarnings("unchecked")
1271+
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> getAuthorizationRequestRepository() {
1272+
if (this.authorizationRequestRepository == null) {
1273+
this.authorizationRequestRepository = getBeanOrNull(ServerAuthorizationRequestRepository.class);
1274+
}
1275+
return this.authorizationRequestRepository;
1276+
}
1277+
12491278
private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() {
12501279
ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class);
12511280
if (service == null) {

config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

+79-3
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
import static org.mockito.BDDMockito.given;
2121
import static org.mockito.ArgumentMatchers.any;
2222
import static org.mockito.Mockito.mock;
23+
import static org.mockito.Mockito.times;
2324
import static org.mockito.Mockito.verify;
2425
import static org.mockito.Mockito.verifyZeroInteractions;
2526
import static org.mockito.Mockito.when;
2627
import static org.springframework.security.config.Customizer.withDefaults;
2728

2829
import java.util.Arrays;
30+
import java.util.Collections;
2931
import java.util.List;
3032
import java.util.Objects;
3133
import java.util.Optional;
@@ -41,25 +43,37 @@
4143
import org.springframework.security.core.Authentication;
4244
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
4345
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
46+
import org.springframework.web.server.handler.FilteringWebHandler;
4447
import reactor.core.publisher.Mono;
4548
import reactor.test.publisher.TestPublisher;
4649

4750
import org.springframework.security.authentication.ReactiveAuthenticationManager;
4851
import org.springframework.security.authentication.TestingAuthenticationToken;
4952
import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
5053
import org.springframework.security.core.context.SecurityContext;
54+
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
55+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
56+
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
57+
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
58+
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
59+
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
60+
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
5161
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
5262
import org.springframework.security.web.server.SecurityWebFilterChain;
5363
import org.springframework.security.web.server.WebFilterChainProxy;
64+
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
65+
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
5466
import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
5567
import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
5668
import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
5769
import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
70+
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
5871
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
5972
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
6073
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
6174
import org.springframework.security.web.server.csrf.CsrfWebFilter;
6275
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
76+
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
6377
import org.springframework.test.util.ReflectionTestUtils;
6478
import org.springframework.test.web.reactive.server.EntityExchangeResult;
6579
import org.springframework.test.web.reactive.server.FluxExchangeResult;
@@ -68,10 +82,7 @@
6882
import org.springframework.web.bind.annotation.RestController;
6983
import org.springframework.web.server.ServerWebExchange;
7084
import org.springframework.web.server.WebFilter;
71-
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
7285
import org.springframework.web.server.WebFilterChain;
73-
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
74-
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
7586

7687
/**
7788
* @author Rob Winch
@@ -475,6 +486,71 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
475486
verify(customServerCsrfTokenRepository).loadToken(any());
476487
}
477488

489+
@SuppressWarnings("UnassignedFluxMonoInstance")
490+
@Test
491+
public void configureOAuth2LoginUsingCustomCommonServerRequestCache() {
492+
ServerRequestCache requestCacheMock = mock(ServerRequestCache.class);
493+
when(requestCacheMock.saveRequest(any(ServerWebExchange.class))).thenReturn(Mono.empty());
494+
495+
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
496+
String registrationId = clientRegistration.getRegistrationId();
497+
498+
ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
499+
mock(ReactiveClientRegistrationRepository.class);
500+
when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
501+
.thenReturn(Mono.just(clientRegistration));
502+
503+
SecurityWebFilterChain filterChain = http.requestCache().requestCache(requestCacheMock)
504+
.and().oauth2Login().clientRegistrationRepository(clientRegistrationRepositoryMock)
505+
.and().build();
506+
507+
Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
508+
getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
509+
assertThat(redirectWebFilter.isPresent()).isTrue();
510+
511+
FilteringWebHandler webHandler = new FilteringWebHandler(
512+
e -> Mono.error(new ClientAuthorizationRequiredException(registrationId)),
513+
Collections.singletonList(redirectWebFilter.get())
514+
);
515+
WebTestClient client = WebTestClient.bindToWebHandler(webHandler).build();
516+
client.get().uri("/foo/bar").exchange();
517+
verify(requestCacheMock, times(1)).saveRequest(any(ServerWebExchange.class));
518+
}
519+
520+
@Test(expected = IllegalArgumentException.class)
521+
public void throwExceptionWhenNullPassedForOAuth2LoginAuthorizationRequestRepository() {
522+
http.oauth2Login().authorizationRequestRepository(null).and().build();
523+
}
524+
525+
@SuppressWarnings({"UnassignedFluxMonoInstance", "unchecked"})
526+
@Test
527+
public void configureOAuth2LoginUsingCustomAuthorizationRequestRepository() {
528+
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
529+
String registrationId = clientRegistration.getRegistrationId();
530+
531+
ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
532+
mock(ReactiveClientRegistrationRepository.class);
533+
when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
534+
.thenReturn(Mono.just(clientRegistration));
535+
536+
ServerAuthorizationRequestRepository requestRepositoryMock = mock(ServerAuthorizationRequestRepository.class);
537+
SecurityWebFilterChain filterChain = http.oauth2Login()
538+
.clientRegistrationRepository(clientRegistrationRepositoryMock)
539+
.authorizationRequestRepository(requestRepositoryMock)
540+
.and().build();
541+
542+
Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
543+
getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
544+
assertThat(redirectWebFilter.isPresent()).isTrue();
545+
546+
WebTestClient client = WebTestClient.bindToController(new SubscriberContextController())
547+
.webFilter(redirectWebFilter.get())
548+
.build();
549+
client.get().uri("/oauth2/authorization/" + registrationId).exchange();
550+
verify(requestRepositoryMock, times(1)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class),
551+
any(ServerWebExchange.class));
552+
}
553+
478554
private boolean isX509Filter(WebFilter filter) {
479555
try {
480556
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");

0 commit comments

Comments
 (0)