Skip to content

Commit 556891b

Browse files
Merge branch '6.0.x'
Closes gh-12512
2 parents f981f70 + d1fc789 commit 556891b

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java

+19
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
6060
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
6161
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
62+
import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
63+
import org.springframework.security.web.context.SecurityContextRepository;
6264
import org.springframework.security.web.util.UrlUtils;
6365
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
6466
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -146,6 +148,8 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
146148

147149
private AuthenticationFailureHandler failureHandler;
148150

151+
private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository();
152+
149153
@Override
150154
public void afterPropertiesSet() {
151155
Assert.notNull(this.userDetailsService, "userDetailsService must be specified");
@@ -183,6 +187,7 @@ private void doFilter(HttpServletRequest request, HttpServletResponse response,
183187
context.setAuthentication(targetUser);
184188
this.securityContextHolderStrategy.setContext(context);
185189
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", targetUser));
190+
this.securityContextRepository.saveContext(context, request, response);
186191
// redirect to target url
187192
this.successHandler.onAuthenticationSuccess(request, response, targetUser);
188193
}
@@ -200,6 +205,7 @@ private void doFilter(HttpServletRequest request, HttpServletResponse response,
200205
context.setAuthentication(originalUser);
201206
this.securityContextHolderStrategy.setContext(context);
202207
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", originalUser));
208+
this.securityContextRepository.saveContext(context, request, response);
203209
// redirect to target url
204210
this.successHandler.onAuthenticationSuccess(request, response, originalUser);
205211
return;
@@ -525,6 +531,19 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
525531
this.securityContextHolderStrategy = securityContextHolderStrategy;
526532
}
527533

534+
/**
535+
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
536+
* switch user success. The default is
537+
* {@link RequestAttributeSecurityContextRepository}.
538+
* @param securityContextRepository the {@link SecurityContextRepository} to use.
539+
* Cannot be null.
540+
* @since 5.7.7
541+
*/
542+
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
543+
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
544+
this.securityContextRepository = securityContextRepository;
545+
}
546+
528547
private static RequestMatcher createMatcher(String pattern) {
529548
return new AntPathRequestMatcher(pattern, "POST", true, new UrlPathHelper());
530549
}

web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java

+60
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616

1717
package org.springframework.security.web.authentication.switchuser;
1818

19+
import java.io.IOException;
1920
import java.util.ArrayList;
2021
import java.util.List;
2122

2223
import jakarta.servlet.FilterChain;
24+
import jakarta.servlet.ServletException;
2325
import org.junit.jupiter.api.AfterEach;
2426
import org.junit.jupiter.api.BeforeEach;
2527
import org.junit.jupiter.api.Test;
2628

29+
import org.springframework.mock.web.MockFilterChain;
2730
import org.springframework.mock.web.MockHttpServletRequest;
2831
import org.springframework.mock.web.MockHttpServletResponse;
2932
import org.springframework.security.authentication.AccountExpiredException;
@@ -46,11 +49,15 @@
4649
import org.springframework.security.util.FieldUtils;
4750
import org.springframework.security.web.DefaultRedirectStrategy;
4851
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
52+
import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
53+
import org.springframework.security.web.context.SecurityContextRepository;
4954
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
55+
import org.springframework.test.util.ReflectionTestUtils;
5056

5157
import static org.assertj.core.api.Assertions.assertThat;
5258
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
5359
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
60+
import static org.mockito.ArgumentMatchers.any;
5461
import static org.mockito.Mockito.atLeastOnce;
5562
import static org.mockito.Mockito.mock;
5663
import static org.mockito.Mockito.never;
@@ -502,6 +509,59 @@ public void setSwitchFailureUrlWhenValidThenNoException() {
502509
filter.setSwitchFailureUrl("/foo");
503510
}
504511

512+
@Test
513+
void filterWhenDefaultSecurityContextRepositoryThenRequestAttributeRepository() {
514+
SwitchUserFilter switchUserFilter = new SwitchUserFilter();
515+
assertThat(ReflectionTestUtils.getField(switchUserFilter, "securityContextRepository"))
516+
.isInstanceOf(RequestAttributeSecurityContextRepository.class);
517+
}
518+
519+
@Test
520+
void doFilterWhenSwitchUserThenSaveSecurityContext() throws ServletException, IOException {
521+
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
522+
MockHttpServletRequest request = new MockHttpServletRequest();
523+
MockHttpServletResponse response = new MockHttpServletResponse();
524+
MockFilterChain filterChain = new MockFilterChain();
525+
request.setParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord");
526+
request.setRequestURI("/login/impersonate");
527+
SwitchUserFilter filter = new SwitchUserFilter();
528+
filter.setSecurityContextRepository(securityContextRepository);
529+
filter.setUserDetailsService(new MockUserDetailsService());
530+
filter.setTargetUrl("/target");
531+
filter.afterPropertiesSet();
532+
533+
filter.doFilter(request, response, filterChain);
534+
535+
verify(securityContextRepository).saveContext(any(), any(), any());
536+
}
537+
538+
@Test
539+
void doFilterWhenExitUserThenSaveSecurityContext() throws ServletException, IOException {
540+
UsernamePasswordAuthenticationToken source = UsernamePasswordAuthenticationToken.authenticated("dano",
541+
"hawaii50", ROLES_12);
542+
// set current user (Admin)
543+
List<GrantedAuthority> adminAuths = new ArrayList<>(ROLES_12);
544+
adminAuths.add(new SwitchUserGrantedAuthority("PREVIOUS_ADMINISTRATOR", source));
545+
UsernamePasswordAuthenticationToken admin = UsernamePasswordAuthenticationToken.authenticated("jacklord",
546+
"hawaii50", adminAuths);
547+
SecurityContextHolder.getContext().setAuthentication(admin);
548+
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
549+
MockHttpServletRequest request = new MockHttpServletRequest();
550+
MockHttpServletResponse response = new MockHttpServletResponse();
551+
MockFilterChain filterChain = new MockFilterChain();
552+
request.setParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord");
553+
request.setRequestURI("/logout/impersonate");
554+
SwitchUserFilter filter = new SwitchUserFilter();
555+
filter.setSecurityContextRepository(securityContextRepository);
556+
filter.setUserDetailsService(new MockUserDetailsService());
557+
filter.setTargetUrl("/target");
558+
filter.afterPropertiesSet();
559+
560+
filter.doFilter(request, response, filterChain);
561+
562+
verify(securityContextRepository).saveContext(any(), any(), any());
563+
}
564+
505565
private class MockUserDetailsService implements UserDetailsService {
506566

507567
private String password = "hawaii50";

0 commit comments

Comments
 (0)