Skip to content

Allow to customize OAuth2AuthorizationRequestRedirectWebFilter in OAuth2LoginSpec #7467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
Expand Down Expand Up @@ -969,6 +970,8 @@ public class OAuth2LoginSpec {

private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;

private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;

private ReactiveAuthenticationManager authenticationManager;

private ServerSecurityContextRepository securityContextRepository;
Expand Down Expand Up @@ -1099,6 +1102,18 @@ public OAuth2LoginSpec authorizedClientRepository(ServerOAuth2AuthorizedClientRe
return this;
}

/**
* Sets authorization request repository for {@link OAuth2AuthorizationRequestRedirectWebFilter}.
*
* @param authorizationRequestRepository authorization request repository, must not be null
* @return the {@link OAuth2LoginSpec} for further configuration
*/
public OAuth2LoginSpec authorizationRequestRepository(ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository) {
Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null");
this.authorizationRequestRepository = authorizationRequestRepository;
return this;
}

/**
* Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
*
Expand Down Expand Up @@ -1143,6 +1158,12 @@ protected void configure(ServerHttpSecurity http) {
ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter();
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
getAuthorizationRequestRepository();
if (authorizationRequestRepository != null) {
oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository);
}
oauthRedirectFilter.setRequestCache(http.requestCache.requestCache);

ReactiveAuthenticationManager manager = getAuthenticationManager();

Expand Down Expand Up @@ -1243,6 +1264,14 @@ private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() {
return result;
}

@SuppressWarnings("unchecked")
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> getAuthorizationRequestRepository() {
if (this.authorizationRequestRepository == null) {
this.authorizationRequestRepository = getBeanOrNull(ServerAuthorizationRequestRepository.class);
}
return this.authorizationRequestRepository;
}

private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() {
ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class);
if (service == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import static org.mockito.BDDMockito.given;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand All @@ -41,25 +43,37 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
import org.springframework.web.server.handler.FilteringWebHandler;
import reactor.core.publisher.Mono;
import reactor.test.publisher.TestPublisher;

import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.reactive.server.EntityExchangeResult;
import org.springframework.test.web.reactive.server.FluxExchangeResult;
Expand All @@ -68,10 +82,7 @@
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;

/**
* @author Rob Winch
Expand Down Expand Up @@ -475,6 +486,71 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
verify(customServerCsrfTokenRepository).loadToken(any());
}

@SuppressWarnings("UnassignedFluxMonoInstance")
@Test
public void configureOAuth2LoginUsingCustomCommonServerRequestCache() {
ServerRequestCache requestCacheMock = mock(ServerRequestCache.class);
when(requestCacheMock.saveRequest(any(ServerWebExchange.class))).thenReturn(Mono.empty());

ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
String registrationId = clientRegistration.getRegistrationId();

ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
mock(ReactiveClientRegistrationRepository.class);
when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
.thenReturn(Mono.just(clientRegistration));

SecurityWebFilterChain filterChain = http.requestCache().requestCache(requestCacheMock)
.and().oauth2Login().clientRegistrationRepository(clientRegistrationRepositoryMock)
.and().build();

Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
assertThat(redirectWebFilter.isPresent()).isTrue();

FilteringWebHandler webHandler = new FilteringWebHandler(
e -> Mono.error(new ClientAuthorizationRequiredException(registrationId)),
Collections.singletonList(redirectWebFilter.get())
);
WebTestClient client = WebTestClient.bindToWebHandler(webHandler).build();
client.get().uri("/foo/bar").exchange();
verify(requestCacheMock, times(1)).saveRequest(any(ServerWebExchange.class));
}

@Test(expected = IllegalArgumentException.class)
public void throwExceptionWhenNullPassedForOAuth2LoginAuthorizationRequestRepository() {
http.oauth2Login().authorizationRequestRepository(null).and().build();
}

@SuppressWarnings({"UnassignedFluxMonoInstance", "unchecked"})
@Test
public void configureOAuth2LoginUsingCustomAuthorizationRequestRepository() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
String registrationId = clientRegistration.getRegistrationId();

ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
mock(ReactiveClientRegistrationRepository.class);
when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
.thenReturn(Mono.just(clientRegistration));

ServerAuthorizationRequestRepository requestRepositoryMock = mock(ServerAuthorizationRequestRepository.class);
SecurityWebFilterChain filterChain = http.oauth2Login()
.clientRegistrationRepository(clientRegistrationRepositoryMock)
.authorizationRequestRepository(requestRepositoryMock)
.and().build();

Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
assertThat(redirectWebFilter.isPresent()).isTrue();

WebTestClient client = WebTestClient.bindToController(new SubscriberContextController())
.webFilter(redirectWebFilter.get())
.build();
client.get().uri("/oauth2/authorization/" + registrationId).exchange();
verify(requestRepositoryMock, times(1)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class),
any(ServerWebExchange.class));
}

private boolean isX509Filter(WebFilter filter) {
try {
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
Expand Down