From 566a0aa95762306dc486ac3d9ce80428a87434f6 Mon Sep 17 00:00:00 2001 From: Roman Chigvintsev Date: Mon, 23 Sep 2019 14:12:16 +0300 Subject: [PATCH] Allow to customize OAuth2AuthorizationRequestRedirectWebFilter in OAuth2 login configuration Fixes gh-7466 --- .../config/web/server/ServerHttpSecurity.java | 29 +++++++ .../web/server/ServerHttpSecurityTests.java | 82 ++++++++++++++++++- 2 files changed, 108 insertions(+), 3 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 405121fbf72..90e92faa575 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -74,6 +74,7 @@ import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; @@ -969,6 +970,8 @@ public class OAuth2LoginSpec { private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ServerAuthorizationRequestRepository authorizationRequestRepository; + private ReactiveAuthenticationManager authenticationManager; private ServerSecurityContextRepository securityContextRepository; @@ -1099,6 +1102,18 @@ public OAuth2LoginSpec authorizedClientRepository(ServerOAuth2AuthorizedClientRe return this; } + /** + * Sets authorization request repository for {@link OAuth2AuthorizationRequestRedirectWebFilter}. + * + * @param authorizationRequestRepository authorization request repository, must not be null + * @return the {@link OAuth2LoginSpec} for further configuration + */ + public OAuth2LoginSpec authorizationRequestRepository(ServerAuthorizationRequestRepository authorizationRequestRepository) { + Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); + this.authorizationRequestRepository = authorizationRequestRepository; + return this; + } + /** * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s. * @@ -1143,6 +1158,12 @@ protected void configure(ServerHttpSecurity http) { ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(); ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository(); OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter(); + ServerAuthorizationRequestRepository authorizationRequestRepository = + getAuthorizationRequestRepository(); + if (authorizationRequestRepository != null) { + oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository); + } + oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); ReactiveAuthenticationManager manager = getAuthenticationManager(); @@ -1243,6 +1264,14 @@ private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() { return result; } + @SuppressWarnings("unchecked") + private ServerAuthorizationRequestRepository getAuthorizationRequestRepository() { + if (this.authorizationRequestRepository == null) { + this.authorizationRequestRepository = getBeanOrNull(ServerAuthorizationRequestRepository.class); + } + return this.authorizationRequestRepository; + } + private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() { ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); if (service == null) { diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index c95f8bd17d7..01be4c914da 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -20,12 +20,14 @@ import static org.mockito.BDDMockito.given; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -41,6 +43,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; +import org.springframework.web.server.handler.FilteringWebHandler; import reactor.core.publisher.Mono; import reactor.test.publisher.TestPublisher; @@ -48,18 +51,29 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; +import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; +import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.LogoutWebFilter; import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult; @@ -68,10 +82,7 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; -import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.web.server.WebFilterChain; -import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; -import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; /** * @author Rob Winch @@ -475,6 +486,71 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() { verify(customServerCsrfTokenRepository).loadToken(any()); } + @SuppressWarnings("UnassignedFluxMonoInstance") + @Test + public void configureOAuth2LoginUsingCustomCommonServerRequestCache() { + ServerRequestCache requestCacheMock = mock(ServerRequestCache.class); + when(requestCacheMock.saveRequest(any(ServerWebExchange.class))).thenReturn(Mono.empty()); + + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + String registrationId = clientRegistration.getRegistrationId(); + + ReactiveClientRegistrationRepository clientRegistrationRepositoryMock = + mock(ReactiveClientRegistrationRepository.class); + when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId)) + .thenReturn(Mono.just(clientRegistration)); + + SecurityWebFilterChain filterChain = http.requestCache().requestCache(requestCacheMock) + .and().oauth2Login().clientRegistrationRepository(clientRegistrationRepositoryMock) + .and().build(); + + Optional redirectWebFilter = + getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class); + assertThat(redirectWebFilter.isPresent()).isTrue(); + + FilteringWebHandler webHandler = new FilteringWebHandler( + e -> Mono.error(new ClientAuthorizationRequiredException(registrationId)), + Collections.singletonList(redirectWebFilter.get()) + ); + WebTestClient client = WebTestClient.bindToWebHandler(webHandler).build(); + client.get().uri("/foo/bar").exchange(); + verify(requestCacheMock, times(1)).saveRequest(any(ServerWebExchange.class)); + } + + @Test(expected = IllegalArgumentException.class) + public void throwExceptionWhenNullPassedForOAuth2LoginAuthorizationRequestRepository() { + http.oauth2Login().authorizationRequestRepository(null).and().build(); + } + + @SuppressWarnings({"UnassignedFluxMonoInstance", "unchecked"}) + @Test + public void configureOAuth2LoginUsingCustomAuthorizationRequestRepository() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + String registrationId = clientRegistration.getRegistrationId(); + + ReactiveClientRegistrationRepository clientRegistrationRepositoryMock = + mock(ReactiveClientRegistrationRepository.class); + when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId)) + .thenReturn(Mono.just(clientRegistration)); + + ServerAuthorizationRequestRepository requestRepositoryMock = mock(ServerAuthorizationRequestRepository.class); + SecurityWebFilterChain filterChain = http.oauth2Login() + .clientRegistrationRepository(clientRegistrationRepositoryMock) + .authorizationRequestRepository(requestRepositoryMock) + .and().build(); + + Optional redirectWebFilter = + getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class); + assertThat(redirectWebFilter.isPresent()).isTrue(); + + WebTestClient client = WebTestClient.bindToController(new SubscriberContextController()) + .webFilter(redirectWebFilter.get()) + .build(); + client.get().uri("/oauth2/authorization/" + registrationId).exchange(); + verify(requestRepositoryMock, times(1)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class), + any(ServerWebExchange.class)); + } + private boolean isX509Filter(WebFilter filter) { try { Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");