|
16 | 16 |
|
17 | 17 | package org.springframework.security.web.context;
|
18 | 18 |
|
| 19 | +import java.io.IOException; |
19 | 20 | import java.lang.annotation.ElementType;
|
20 | 21 | import java.lang.annotation.Retention;
|
21 | 22 | import java.lang.annotation.RetentionPolicy;
|
22 | 23 | import java.lang.annotation.Target;
|
23 | 24 |
|
| 25 | +import javax.servlet.Filter; |
| 26 | +import javax.servlet.ServletException; |
24 | 27 | import javax.servlet.ServletOutputStream;
|
| 28 | +import javax.servlet.http.HttpServlet; |
25 | 29 | import javax.servlet.http.HttpServletRequest;
|
26 | 30 | import javax.servlet.http.HttpServletRequestWrapper;
|
27 | 31 | import javax.servlet.http.HttpServletResponse;
|
|
31 | 35 | import org.junit.After;
|
32 | 36 | import org.junit.Test;
|
33 | 37 |
|
| 38 | +import org.springframework.mock.web.MockFilterChain; |
34 | 39 | import org.springframework.mock.web.MockHttpServletRequest;
|
35 | 40 | import org.springframework.mock.web.MockHttpServletResponse;
|
36 | 41 | import org.springframework.mock.web.MockHttpSession;
|
37 | 42 | import org.springframework.security.authentication.AbstractAuthenticationToken;
|
38 | 43 | import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
39 | 44 | import org.springframework.security.authentication.AuthenticationTrustResolver;
|
40 | 45 | import org.springframework.security.authentication.TestingAuthenticationToken;
|
| 46 | +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; |
41 | 47 | import org.springframework.security.core.Transient;
|
42 | 48 | import org.springframework.security.core.authority.AuthorityUtils;
|
43 | 49 | import org.springframework.security.core.context.SecurityContext;
|
44 | 50 | import org.springframework.security.core.context.SecurityContextHolder;
|
| 51 | +import org.springframework.security.core.context.SecurityContextImpl; |
| 52 | +import org.springframework.security.core.userdetails.User; |
| 53 | +import org.springframework.security.core.userdetails.UserDetails; |
45 | 54 |
|
46 | 55 | import static org.assertj.core.api.Assertions.assertThat;
|
47 | 56 | import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
@@ -162,6 +171,48 @@ public void saveContextCallsSetAttributeIfContextIsModifiedDirectlyDuringRequest
|
162 | 171 | verify(session).setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, ctx);
|
163 | 172 | }
|
164 | 173 |
|
| 174 | + @Test |
| 175 | + public void saveContextWhenSaveNewContextThenOriginalContextThenOriginalContextSaved() throws Exception { |
| 176 | + HttpSessionSecurityContextRepository repository = new HttpSessionSecurityContextRepository(); |
| 177 | + SecurityContextPersistenceFilter securityContextPersistenceFilter = new SecurityContextPersistenceFilter( |
| 178 | + repository); |
| 179 | + |
| 180 | + UserDetails original = User.withUsername("user").password("password").roles("USER").build(); |
| 181 | + SecurityContext originalContext = createSecurityContext(original); |
| 182 | + UserDetails impersonate = User.withUserDetails(original).username("impersonate").build(); |
| 183 | + SecurityContext impersonateContext = createSecurityContext(impersonate); |
| 184 | + |
| 185 | + MockHttpServletRequest mockRequest = new MockHttpServletRequest(); |
| 186 | + MockHttpServletResponse mockResponse = new MockHttpServletResponse(); |
| 187 | + |
| 188 | + Filter saveImpersonateContext = (request, response, chain) -> { |
| 189 | + SecurityContextHolder.setContext(impersonateContext); |
| 190 | + // ensure the response is committed to trigger save |
| 191 | + response.flushBuffer(); |
| 192 | + chain.doFilter(request, response); |
| 193 | + }; |
| 194 | + Filter saveOriginalContext = (request, response, chain) -> { |
| 195 | + SecurityContextHolder.setContext(originalContext); |
| 196 | + chain.doFilter(request, response); |
| 197 | + }; |
| 198 | + HttpServlet servlet = new HttpServlet() { |
| 199 | + @Override |
| 200 | + protected void service(HttpServletRequest req, HttpServletResponse resp) |
| 201 | + throws ServletException, IOException { |
| 202 | + resp.getWriter().write("Hi"); |
| 203 | + } |
| 204 | + }; |
| 205 | + |
| 206 | + SecurityContextHolder.setContext(originalContext); |
| 207 | + MockFilterChain chain = new MockFilterChain(servlet, saveImpersonateContext, saveOriginalContext); |
| 208 | + |
| 209 | + securityContextPersistenceFilter.doFilter(mockRequest, mockResponse, chain); |
| 210 | + |
| 211 | + assertThat( |
| 212 | + mockRequest.getSession().getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY)) |
| 213 | + .isEqualTo(originalContext); |
| 214 | + } |
| 215 | + |
165 | 216 | @Test
|
166 | 217 | public void nonSecurityContextInSessionIsIgnored() {
|
167 | 218 | HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
|
@@ -577,6 +628,13 @@ public void saveContextWhenTransientAuthenticationWithCustomAnnotationThenSkippe
|
577 | 628 | assertThat(session).isNull();
|
578 | 629 | }
|
579 | 630 |
|
| 631 | + private SecurityContext createSecurityContext(UserDetails userDetails) { |
| 632 | + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(userDetails, |
| 633 | + userDetails.getPassword(), userDetails.getAuthorities()); |
| 634 | + SecurityContext securityContext = new SecurityContextImpl(token); |
| 635 | + return securityContext; |
| 636 | + } |
| 637 | + |
580 | 638 | @Transient
|
581 | 639 | private static class SomeTransientAuthentication extends AbstractAuthenticationToken {
|
582 | 640 |
|
|
0 commit comments