Skip to content

Commit fa9898d

Browse files
dadikovirwinch
authored andcommitted
formLogin() and login() implement Mergable
This is necessary so that default requests like Spring REST Docs work. Closes gh-7572
1 parent bff6d82 commit fa9898d

File tree

4 files changed

+141
-16
lines changed

4 files changed

+141
-16
lines changed

test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java

+66-11
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@
1515
*/
1616
package org.springframework.security.test.web.servlet.request;
1717

18-
import javax.servlet.ServletContext;
19-
18+
import org.springframework.beans.Mergeable;
2019
import org.springframework.http.MediaType;
2120
import org.springframework.mock.web.MockHttpServletRequest;
2221
import org.springframework.security.web.csrf.CsrfToken;
2322
import org.springframework.test.web.servlet.MockMvc;
2423
import org.springframework.test.web.servlet.RequestBuilder;
24+
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
2525
import org.springframework.test.web.servlet.request.RequestPostProcessor;
2626
import org.springframework.web.util.UriComponentsBuilder;
2727

28+
import javax.servlet.ServletContext;
29+
2830
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
2931
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
3032

@@ -86,15 +88,23 @@ public static LogoutRequestBuilder logout(String logoutUrl) {
8688
* @author Rob Winch
8789
* @since 4.0
8890
*/
89-
public static final class LogoutRequestBuilder implements RequestBuilder {
91+
public static final class LogoutRequestBuilder implements RequestBuilder, Mergeable {
9092
private String logoutUrl = "/logout";
9193
private RequestPostProcessor postProcessor = csrf();
94+
private Mergeable parent;
9295

9396
@Override
9497
public MockHttpServletRequest buildRequest(ServletContext servletContext) {
95-
MockHttpServletRequest request = post(this.logoutUrl)
96-
.accept(MediaType.TEXT_HTML, MediaType.ALL)
97-
.buildRequest(servletContext);
98+
MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl)
99+
.accept(MediaType.TEXT_HTML, MediaType.ALL);
100+
101+
if (this.parent != null) {
102+
logoutRequest = (MockHttpServletRequestBuilder) logoutRequest.merge(this.parent);
103+
}
104+
105+
MockHttpServletRequest request = logoutRequest.buildRequest(servletContext);
106+
logoutRequest.postProcessRequest(request);
107+
98108
return this.postProcessor.postProcessRequest(request);
99109
}
100110

@@ -122,6 +132,24 @@ public LogoutRequestBuilder logoutUrl(String logoutUrl, Object... uriVars) {
122132
return this;
123133
}
124134

135+
@Override
136+
public boolean isMergeEnabled() {
137+
return true;
138+
}
139+
140+
@Override
141+
public Object merge(Object parent) {
142+
if (parent == null) {
143+
return this;
144+
}
145+
if (parent instanceof Mergeable) {
146+
this.parent = (Mergeable) parent;
147+
return this;
148+
} else {
149+
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
150+
}
151+
}
152+
125153
private LogoutRequestBuilder() {
126154
}
127155
}
@@ -132,22 +160,31 @@ private LogoutRequestBuilder() {
132160
* @author Rob Winch
133161
* @since 4.0
134162
*/
135-
public static final class FormLoginRequestBuilder implements RequestBuilder {
163+
public static final class FormLoginRequestBuilder implements RequestBuilder, Mergeable {
136164
private String usernameParam = "username";
137165
private String passwordParam = "password";
138166
private String username = "user";
139167
private String password = "password";
140168
private String loginProcessingUrl = "/login";
141169
private MediaType acceptMediaType = MediaType.APPLICATION_FORM_URLENCODED;
170+
private Mergeable parent;
142171

143172
private RequestPostProcessor postProcessor = csrf();
144173

145174
@Override
146175
public MockHttpServletRequest buildRequest(ServletContext servletContext) {
147-
MockHttpServletRequest request = post(this.loginProcessingUrl)
148-
.accept(this.acceptMediaType).param(this.usernameParam, this.username)
149-
.param(this.passwordParam, this.password)
150-
.buildRequest(servletContext);
176+
MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl)
177+
.accept(this.acceptMediaType)
178+
.param(this.usernameParam, this.username)
179+
.param(this.passwordParam, this.password);
180+
181+
if (this.parent != null) {
182+
loginRequest = (MockHttpServletRequestBuilder) loginRequest.merge(this.parent);
183+
}
184+
185+
MockHttpServletRequest request = loginRequest.buildRequest(servletContext);
186+
loginRequest.postProcessRequest(request);
187+
151188
return this.postProcessor.postProcessRequest(request);
152189
}
153190

@@ -258,6 +295,24 @@ public FormLoginRequestBuilder acceptMediaType(MediaType acceptMediaType) {
258295
return this;
259296
}
260297

298+
@Override
299+
public boolean isMergeEnabled() {
300+
return true;
301+
}
302+
303+
@Override
304+
public Object merge(Object parent) {
305+
if (parent == null) {
306+
return this;
307+
}
308+
if (parent instanceof Mergeable ) {
309+
this.parent = (Mergeable) parent;
310+
return this;
311+
} else {
312+
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
313+
}
314+
}
315+
261316
private FormLoginRequestBuilder() {
262317
}
263318
}

test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ static class TestCsrfTokenRepository implements CsrfTokenRepository {
410410

411411
private final CsrfTokenRepository delegate;
412412

413-
private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
413+
TestCsrfTokenRepository(CsrfTokenRepository delegate) {
414414
this.delegate = delegate;
415415
}
416416

test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

+37-1
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,25 @@
1717

1818
import org.junit.Before;
1919
import org.junit.Test;
20-
20+
import org.springframework.http.HttpMethod;
2121
import org.springframework.http.MediaType;
2222
import org.springframework.mock.web.MockHttpServletRequest;
2323
import org.springframework.mock.web.MockServletContext;
2424
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
2525
import org.springframework.security.web.csrf.CsrfToken;
26+
import org.springframework.test.web.servlet.MockMvc;
27+
import org.springframework.test.web.servlet.MvcResult;
28+
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
29+
import org.springframework.test.web.servlet.request.RequestPostProcessor;
30+
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
31+
32+
import java.util.Arrays;
2633

2734
import static org.assertj.core.api.Assertions.assertThat;
35+
import static org.mockito.ArgumentMatchers.any;
36+
import static org.mockito.Mockito.mock;
37+
import static org.mockito.Mockito.verify;
38+
import static org.powermock.api.mockito.PowerMockito.when;
2839
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
2940

3041
public class SecurityMockMvcRequestBuildersFormLoginTests {
@@ -82,6 +93,31 @@ public void customWithUriVars() {
8293
assertThat(request.getRequestURI()).isEqualTo("/uri-login/val1/val2");
8394
}
8495

96+
/**
97+
* spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together
98+
* with our request builders. (gh-7572)
99+
* @throws Exception
100+
*/
101+
@Test
102+
public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception {
103+
RequestPostProcessor postProcessor = mock(RequestPostProcessor.class);
104+
when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0));
105+
MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object())
106+
.defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor))
107+
.build();
108+
109+
110+
MvcResult mvcResult = mockMvc.perform(formLogin()).andReturn();
111+
assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name());
112+
assertThat(mvcResult.getRequest().getHeader("Accept"))
113+
.isEqualTo(MediaType.toString(Arrays.asList(MediaType.APPLICATION_FORM_URLENCODED)));
114+
assertThat(mvcResult.getRequest().getParameter("username")).isEqualTo("user");
115+
assertThat(mvcResult.getRequest().getParameter("password")).isEqualTo("password");
116+
assertThat(mvcResult.getRequest().getRequestURI()).isEqualTo("/login");
117+
assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty();
118+
verify(postProcessor).postProcessRequest(any());
119+
}
120+
85121
// gh-3920
86122
@Test
87123
public void usesAcceptMediaForContentNegotiation() {

test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

+37-3
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,28 @@
1515
*/
1616
package org.springframework.security.test.web.servlet.request;
1717

18-
import static org.assertj.core.api.Assertions.assertThat;
19-
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
20-
2118
import org.junit.Before;
2219
import org.junit.Test;
20+
import org.springframework.http.HttpMethod;
21+
import org.springframework.http.MediaType;
2322
import org.springframework.mock.web.MockHttpServletRequest;
2423
import org.springframework.mock.web.MockServletContext;
2524
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
2625
import org.springframework.security.web.csrf.CsrfToken;
26+
import org.springframework.test.web.servlet.MockMvc;
27+
import org.springframework.test.web.servlet.MvcResult;
28+
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
29+
import org.springframework.test.web.servlet.request.RequestPostProcessor;
30+
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
31+
32+
import java.util.Arrays;
33+
34+
import static org.assertj.core.api.Assertions.assertThat;
35+
import static org.mockito.ArgumentMatchers.any;
36+
import static org.mockito.Mockito.mock;
37+
import static org.mockito.Mockito.verify;
38+
import static org.powermock.api.mockito.PowerMockito.when;
39+
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
2740

2841
public class SecurityMockMvcRequestBuildersFormLogoutTests {
2942
private MockServletContext servletContext;
@@ -71,4 +84,25 @@ public void customWithUriVars() {
7184
assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");
7285
}
7386

87+
/**
88+
* spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together
89+
* with our request builders. (gh-7572)
90+
* @throws Exception
91+
*/
92+
@Test
93+
public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception {
94+
RequestPostProcessor postProcessor = mock(RequestPostProcessor.class);
95+
when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0));
96+
MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object())
97+
.defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor))
98+
.build();
99+
100+
MvcResult mvcResult = mockMvc.perform(logout()).andReturn();
101+
assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name());
102+
assertThat(mvcResult.getRequest().getHeader("Accept"))
103+
.isEqualTo(MediaType.toString(Arrays.asList(MediaType.TEXT_HTML, MediaType.ALL)));
104+
assertThat(mvcResult.getRequest().getRequestURI()).isEqualTo("/logout");
105+
assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty();
106+
verify(postProcessor).postProcessRequest(any());
107+
}
74108
}

0 commit comments

Comments
 (0)