|
20 | 20 | import static org.mockito.BDDMockito.given;
|
21 | 21 | import static org.mockito.ArgumentMatchers.any;
|
22 | 22 | import static org.mockito.Mockito.mock;
|
| 23 | +import static org.mockito.Mockito.times; |
23 | 24 | import static org.mockito.Mockito.verify;
|
24 | 25 | import static org.mockito.Mockito.verifyZeroInteractions;
|
25 | 26 | import static org.mockito.Mockito.when;
|
26 | 27 | import static org.springframework.security.config.Customizer.withDefaults;
|
27 | 28 |
|
28 | 29 | import java.util.Arrays;
|
| 30 | +import java.util.Collections; |
29 | 31 | import java.util.List;
|
30 | 32 | import java.util.Objects;
|
31 | 33 | import java.util.Optional;
|
|
41 | 43 | import org.springframework.security.core.Authentication;
|
42 | 44 | import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
|
43 | 45 | import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
|
| 46 | +import org.springframework.web.server.handler.FilteringWebHandler; |
44 | 47 | import reactor.core.publisher.Mono;
|
45 | 48 | import reactor.test.publisher.TestPublisher;
|
46 | 49 |
|
47 | 50 | import org.springframework.security.authentication.ReactiveAuthenticationManager;
|
48 | 51 | import org.springframework.security.authentication.TestingAuthenticationToken;
|
49 | 52 | import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
|
50 | 53 | import org.springframework.security.core.context.SecurityContext;
|
| 54 | +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; |
| 55 | +import org.springframework.security.oauth2.client.registration.ClientRegistration; |
| 56 | +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; |
| 57 | +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; |
| 58 | +import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; |
| 59 | +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; |
| 60 | +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; |
51 | 61 | import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
|
52 | 62 | import org.springframework.security.web.server.SecurityWebFilterChain;
|
53 | 63 | import org.springframework.security.web.server.WebFilterChainProxy;
|
| 64 | +import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; |
| 65 | +import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; |
54 | 66 | import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
|
55 | 67 | import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
|
56 | 68 | import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
|
57 | 69 | import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
|
| 70 | +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; |
58 | 71 | import org.springframework.security.web.server.context.ServerSecurityContextRepository;
|
59 | 72 | import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
|
60 | 73 | import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
|
61 | 74 | import org.springframework.security.web.server.csrf.CsrfWebFilter;
|
62 | 75 | import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
|
| 76 | +import org.springframework.security.web.server.savedrequest.ServerRequestCache; |
63 | 77 | import org.springframework.test.util.ReflectionTestUtils;
|
64 | 78 | import org.springframework.test.web.reactive.server.EntityExchangeResult;
|
65 | 79 | import org.springframework.test.web.reactive.server.FluxExchangeResult;
|
|
68 | 82 | import org.springframework.web.bind.annotation.RestController;
|
69 | 83 | import org.springframework.web.server.ServerWebExchange;
|
70 | 84 | import org.springframework.web.server.WebFilter;
|
71 |
| -import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; |
72 | 85 | import org.springframework.web.server.WebFilterChain;
|
73 |
| -import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; |
74 |
| -import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; |
75 | 86 |
|
76 | 87 | /**
|
77 | 88 | * @author Rob Winch
|
@@ -475,6 +486,71 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
|
475 | 486 | verify(customServerCsrfTokenRepository).loadToken(any());
|
476 | 487 | }
|
477 | 488 |
|
| 489 | + @SuppressWarnings("UnassignedFluxMonoInstance") |
| 490 | + @Test |
| 491 | + public void configureOAuth2LoginUsingCustomCommonServerRequestCache() { |
| 492 | + ServerRequestCache requestCacheMock = mock(ServerRequestCache.class); |
| 493 | + when(requestCacheMock.saveRequest(any(ServerWebExchange.class))).thenReturn(Mono.empty()); |
| 494 | + |
| 495 | + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); |
| 496 | + String registrationId = clientRegistration.getRegistrationId(); |
| 497 | + |
| 498 | + ReactiveClientRegistrationRepository clientRegistrationRepositoryMock = |
| 499 | + mock(ReactiveClientRegistrationRepository.class); |
| 500 | + when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId)) |
| 501 | + .thenReturn(Mono.just(clientRegistration)); |
| 502 | + |
| 503 | + SecurityWebFilterChain filterChain = http.requestCache().requestCache(requestCacheMock) |
| 504 | + .and().oauth2Login().clientRegistrationRepository(clientRegistrationRepositoryMock) |
| 505 | + .and().build(); |
| 506 | + |
| 507 | + Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter = |
| 508 | + getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class); |
| 509 | + assertThat(redirectWebFilter.isPresent()).isTrue(); |
| 510 | + |
| 511 | + FilteringWebHandler webHandler = new FilteringWebHandler( |
| 512 | + e -> Mono.error(new ClientAuthorizationRequiredException(registrationId)), |
| 513 | + Collections.singletonList(redirectWebFilter.get()) |
| 514 | + ); |
| 515 | + WebTestClient client = WebTestClient.bindToWebHandler(webHandler).build(); |
| 516 | + client.get().uri("/foo/bar").exchange(); |
| 517 | + verify(requestCacheMock, times(1)).saveRequest(any(ServerWebExchange.class)); |
| 518 | + } |
| 519 | + |
| 520 | + @Test(expected = IllegalArgumentException.class) |
| 521 | + public void throwExceptionWhenNullPassedForOAuth2LoginAuthorizationRequestRepository() { |
| 522 | + http.oauth2Login().authorizationRequestRepository(null).and().build(); |
| 523 | + } |
| 524 | + |
| 525 | + @SuppressWarnings({"UnassignedFluxMonoInstance", "unchecked"}) |
| 526 | + @Test |
| 527 | + public void configureOAuth2LoginUsingCustomAuthorizationRequestRepository() { |
| 528 | + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); |
| 529 | + String registrationId = clientRegistration.getRegistrationId(); |
| 530 | + |
| 531 | + ReactiveClientRegistrationRepository clientRegistrationRepositoryMock = |
| 532 | + mock(ReactiveClientRegistrationRepository.class); |
| 533 | + when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId)) |
| 534 | + .thenReturn(Mono.just(clientRegistration)); |
| 535 | + |
| 536 | + ServerAuthorizationRequestRepository requestRepositoryMock = mock(ServerAuthorizationRequestRepository.class); |
| 537 | + SecurityWebFilterChain filterChain = http.oauth2Login() |
| 538 | + .clientRegistrationRepository(clientRegistrationRepositoryMock) |
| 539 | + .authorizationRequestRepository(requestRepositoryMock) |
| 540 | + .and().build(); |
| 541 | + |
| 542 | + Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter = |
| 543 | + getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class); |
| 544 | + assertThat(redirectWebFilter.isPresent()).isTrue(); |
| 545 | + |
| 546 | + WebTestClient client = WebTestClient.bindToController(new SubscriberContextController()) |
| 547 | + .webFilter(redirectWebFilter.get()) |
| 548 | + .build(); |
| 549 | + client.get().uri("/oauth2/authorization/" + registrationId).exchange(); |
| 550 | + verify(requestRepositoryMock, times(1)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class), |
| 551 | + any(ServerWebExchange.class)); |
| 552 | + } |
| 553 | + |
478 | 554 | private boolean isX509Filter(WebFilter filter) {
|
479 | 555 | try {
|
480 | 556 | Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
|
|
0 commit comments