Skip to content

Allow to set default securityContextRepository for each authenticatio… #7275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ public class ServerHttpSecurity {

private ReactiveAuthenticationManager authenticationManager;

private ServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
private ServerSecurityContextRepository securityContextRepository;

private ServerAuthenticationEntryPoint authenticationEntryPoint;

Expand Down Expand Up @@ -346,7 +346,7 @@ private ServerWebExchangeMatcher getSecurityMatcher() {
}

/**
* The strategy used with {@code ReactorContextWebFilter}. It does not impact how the {@code SecurityContext} is
* The strategy used with {@code ReactorContextWebFilter}. It does impact how the {@code SecurityContext} is
* saved which is configured on a per {@link AuthenticationWebFilter} basis.
* @param securityContextRepository the repository to use
* @return the {@link ServerHttpSecurity} to continue configuring
Expand Down Expand Up @@ -971,7 +971,7 @@ public class OAuth2LoginSpec {

private ReactiveAuthenticationManager authenticationManager;

private ServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
private ServerSecurityContextRepository securityContextRepository;

private ServerAuthenticationConverter authenticationConverter;

Expand Down Expand Up @@ -2254,9 +2254,7 @@ public SecurityWebFilterChain build() {
this.headers.configure(this);
}
WebFilter securityContextRepositoryWebFilter = securityContextRepositoryWebFilter();
if (securityContextRepositoryWebFilter != null) {
this.webFilters.add(securityContextRepositoryWebFilter);
}
this.webFilters.add(securityContextRepositoryWebFilter);
if (this.httpsRedirectSpec != null) {
this.httpsRedirectSpec.configure(this);
}
Expand All @@ -2273,18 +2271,42 @@ public SecurityWebFilterChain build() {
if (this.httpBasic.authenticationManager == null) {
this.httpBasic.authenticationManager(this.authenticationManager);
}
if (this.httpBasic.securityContextRepository != null) {
this.httpBasic.securityContextRepository(this.httpBasic.securityContextRepository);
}
else if (this.securityContextRepository != null) {
this.httpBasic.securityContextRepository(this.securityContextRepository);
}
else {
this.httpBasic.securityContextRepository(NoOpServerSecurityContextRepository.getInstance());
}
this.httpBasic.configure(this);
}
if (this.formLogin != null) {
if (this.formLogin.authenticationManager == null) {
this.formLogin.authenticationManager(this.authenticationManager);
}
if (this.securityContextRepository != null) {
if (this.formLogin.securityContextRepository != null) {
this.formLogin.securityContextRepository(this.formLogin.securityContextRepository);
}
else if (this.securityContextRepository != null) {
this.formLogin.securityContextRepository(this.securityContextRepository);
}
else {
this.formLogin.securityContextRepository(new WebSessionServerSecurityContextRepository());
}
this.formLogin.configure(this);
}
if (this.oauth2Login != null) {
if (this.oauth2Login.securityContextRepository != null) {
this.oauth2Login.securityContextRepository(this.oauth2Login.securityContextRepository);
}
else if (this.securityContextRepository != null) {
this.oauth2Login.securityContextRepository(this.securityContextRepository);
}
else {
this.oauth2Login.securityContextRepository(new WebSessionServerSecurityContextRepository());
}
this.oauth2Login.configure(this);
}
if (this.resourceServer != null) {
Expand Down Expand Up @@ -2379,10 +2401,8 @@ public static ServerHttpSecurity http() {
}

private WebFilter securityContextRepositoryWebFilter() {
ServerSecurityContextRepository repository = this.securityContextRepository;
if (repository == null) {
return null;
}
ServerSecurityContextRepository repository = this.securityContextRepository == null ?
new WebSessionServerSecurityContextRepository() : this.securityContextRepository;
WebFilter result = new ReactorContextWebFilter(repository);
return new OrderedWebFilter(result, SecurityWebFiltersOrder.REACTOR_CONTEXT.getOrder());
}
Expand Down Expand Up @@ -2774,7 +2794,7 @@ private RequestCacheSpec() {}
public class HttpBasicSpec {
private ReactiveAuthenticationManager authenticationManager;

private ServerSecurityContextRepository securityContextRepository = NoOpServerSecurityContextRepository.getInstance();
private ServerSecurityContextRepository securityContextRepository;

private ServerAuthenticationEntryPoint entryPoint = new HttpBasicServerAuthenticationEntryPoint();

Expand Down Expand Up @@ -2846,9 +2866,7 @@ protected void configure(ServerHttpSecurity http) {
this.authenticationManager);
authenticationFilter.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint));
authenticationFilter.setAuthenticationConverter(new ServerHttpBasicAuthenticationConverter());
if (this.securityContextRepository != null) {
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
}
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC);
}

Expand All @@ -2869,7 +2887,7 @@ public class FormLoginSpec {

private ReactiveAuthenticationManager authenticationManager;

private ServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
private ServerSecurityContextRepository securityContextRepository;

private ServerAuthenticationEntryPoint authenticationEntryPoint;

Expand Down Expand Up @@ -2966,7 +2984,7 @@ public FormLoginSpec authenticationFailureHandler(ServerAuthenticationFailureHan

/**
* The {@link ServerSecurityContextRepository} used to save the {@code Authentication}. Defaults to
* {@link NoOpServerSecurityContextRepository}. For the {@code SecurityContext} to be loaded on subsequent
* {@link WebSessionServerSecurityContextRepository}. For the {@code SecurityContext} to be loaded on subsequent
* requests the {@link ReactorContextWebFilter} must be configured to be able to load the value (they are not
* implicitly linked).
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.csrf.CsrfToken;
import org.springframework.stereotype.Controller;
import org.springframework.test.web.reactive.server.WebTestClient;
Expand All @@ -44,12 +48,15 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.springframework.security.config.Customizer.withDefaults;

/**
* @author Rob Winch
* @author Eddú Meléndez
* @since 5.0
*/
public class FormLoginTests {
Expand Down Expand Up @@ -272,6 +279,50 @@ public void customAuthenticationManager() {
verifyZeroInteractions(defaultAuthenticationManager);
}

@Test
public void formLoginSecurityContextRepository() {
ServerSecurityContextRepository defaultSecContextRepository = mock(ServerSecurityContextRepository.class);
ServerSecurityContextRepository formLoginSecContextRepository = mock(ServerSecurityContextRepository.class);

TestingAuthenticationToken token = new TestingAuthenticationToken("rob", "rob", "ROLE_USER");

given(defaultSecContextRepository.save(any(), any())).willReturn(Mono.empty());
given(defaultSecContextRepository.load(any())).willReturn(authentication(token));
given(formLoginSecContextRepository.save(any(), any())).willReturn(Mono.empty());
given(formLoginSecContextRepository.load(any())).willReturn(authentication(token));

SecurityWebFilterChain securityWebFilter = this.http
.authorizeExchange()
.anyExchange().authenticated()
.and()
.securityContextRepository(defaultSecContextRepository)
.formLogin()
.securityContextRepository(formLoginSecContextRepository)
.and()
.build();

WebTestClient webTestClient = WebTestClientBuilder
.bindToWebFilters(securityWebFilter)
.build();

WebDriver driver = WebTestClientHtmlUnitDriverBuilder
.webTestClientSetup(webTestClient)
.build();

DefaultLoginPage loginPage = DefaultLoginPage.to(driver)
.assertAt();

HomePage homePage = loginPage.loginForm()
.username("user")
.password("password")
.submit(HomePage.class);

homePage.assertAt();

verify(defaultSecContextRepository, atLeastOnce()).load(any());
verify(formLoginSecContextRepository).save(any(), any());
}

public static class CustomLoginPage {

private WebDriver driver;
Expand Down Expand Up @@ -501,4 +552,10 @@ public Mono<String> login(ServerWebExchange exchange) {
+ "</html>");
}
}

Mono<SecurityContext> authentication(Authentication authentication) {
SecurityContext context = new SecurityContextImpl();
context.setAuthentication(authentication);
return Mono.just(context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ public void oauth2LoginWhenCustomBeansThenUsed() {

ServerSecurityContextRepository securityContextRepository = config.securityContextRepository;
when(securityContextRepository.save(any(), any())).thenReturn(Mono.empty());
when(securityContextRepository.load(any())).thenReturn(authentication(token));

Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@

/**
* @author Rob Winch
* @author Eddú Meléndez
* @since 5.0
*/
@RunWith(MockitoJUnitRunner.class)
Expand Down Expand Up @@ -117,7 +118,6 @@ public void defaults() {
public void basic() {
given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN")));

this.http.securityContextRepository(new WebSessionServerSecurityContextRepository());
this.http.httpBasic();
this.http.authenticationManager(this.authenticationManager);
ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange();
Expand All @@ -137,6 +137,30 @@ public void basic() {
assertThat(result.getResponseCookies().getFirst("SESSION")).isNull();
}

@Test
public void basicWithGlobalWebSessionServerSecurityContextRepository() {
given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN")));

this.http.securityContextRepository(new WebSessionServerSecurityContextRepository());
this.http.httpBasic();
this.http.authenticationManager(this.authenticationManager);
ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange();
authorize.anyExchange().authenticated();

WebTestClient client = buildClient();

EntityExchangeResult<String> result = client.get()
.uri("/")
.headers(headers -> headers.setBasicAuth("rob", "rob"))
.exchange()
.expectStatus().isOk()
.expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+")
.expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok"))
.returnResult();

assertThat(result.getResponseCookies().getFirst("SESSION")).isNotNull();
}

@Test
public void basicWhenNoCredentialsThenUnauthorized() {
this.http.authorizeExchange().anyExchange().authenticated();
Expand Down Expand Up @@ -256,7 +280,6 @@ public void getWhenAnonymousConfiguredThenAuthenticationIsAnonymous() throws Exc
public void basicWithAnonymous() {
given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN")));

this.http.securityContextRepository(new WebSessionServerSecurityContextRepository());
this.http.httpBasic().and().anonymous();
this.http.authenticationManager(this.authenticationManager);
ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange();
Expand Down