Skip to content

Commit 36e2dd9

Browse files
committed
Support contextPath override in ForwardedHeaderFilter
Issue: SPR-13614
1 parent 6fcc869 commit 36e2dd9

File tree

2 files changed

+203
-17
lines changed

2 files changed

+203
-17
lines changed

spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131

3232
import org.springframework.http.HttpRequest;
3333
import org.springframework.http.server.ServletServerHttpRequest;
34+
import org.springframework.util.Assert;
3435
import org.springframework.util.CollectionUtils;
3536
import org.springframework.web.util.UriComponents;
3637
import org.springframework.web.util.UriComponentsBuilder;
38+
import org.springframework.web.util.UrlPathHelper;
3739

3840

3941
/**
@@ -61,6 +63,28 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
6163
}
6264

6365

66+
private ContextPathHelper contextPathHelper;
67+
68+
69+
70+
/**
71+
* Configure a contextPath value that will replace the contextPath of
72+
* proxy-forwarded requests.
73+
*
74+
* <p>This is useful when external clients are not aware of the application
75+
* context path. However a proxy forwards the request to a URL that includes
76+
* a contextPath.
77+
*
78+
* @param contextPath the context path; the given value will be sanitized to
79+
* ensure it starts with a '/' but does not end with one, or if the context
80+
* path is empty (default, root context) it is left as-is.
81+
*/
82+
public void setContextPath(String contextPath) {
83+
Assert.notNull(contextPath, "'contextPath' must not be null");
84+
this.contextPathHelper = new ContextPathHelper(contextPath);
85+
}
86+
87+
6488
@Override
6589
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
6690
Enumeration<String> headerNames = request.getHeaderNames();
@@ -87,7 +111,7 @@ protected boolean shouldNotFilterErrorDispatch() {
87111
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
88112
FilterChain filterChain) throws ServletException, IOException {
89113

90-
filterChain.doFilter(new ForwardedHeaderRequestWrapper(request), response);
114+
filterChain.doFilter(new ForwardedHeaderRequestWrapper(request, this.contextPathHelper), response);
91115
}
92116

93117

@@ -105,12 +129,16 @@ private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWra
105129

106130
private final int port;
107131

132+
private final String contextPath;
133+
134+
private final String requestUri;
135+
108136
private final StringBuffer requestUrl;
109137

110138
private final Map<String, List<String>> headers;
111139

112140

113-
public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
141+
public ForwardedHeaderRequestWrapper(HttpServletRequest request, ContextPathHelper pathHelper) {
114142
super(request);
115143

116144
HttpRequest httpRequest = new ServletServerHttpRequest(request);
@@ -121,7 +149,11 @@ public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
121149
this.secure = "https".equals(scheme);
122150
this.host = uriComponents.getHost();
123151
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
124-
this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI());
152+
153+
this.contextPath = (pathHelper != null ? pathHelper.getContextPath(request) : request.getContextPath());
154+
this.requestUri = (pathHelper != null ? pathHelper.getRequestUri(request) : request.getRequestURI());
155+
this.requestUrl = initRequestUrl(this.scheme, this.host, port, this.requestUri);
156+
125157
this.headers = initHeaders(request);
126158
}
127159

@@ -170,6 +202,16 @@ public boolean isSecure() {
170202
return this.secure;
171203
}
172204

205+
@Override
206+
public String getContextPath() {
207+
return this.contextPath;
208+
}
209+
210+
@Override
211+
public String getRequestURI() {
212+
return this.requestUri;
213+
}
214+
173215
@Override
174216
public StringBuffer getRequestURL() {
175217
return this.requestUrl;
@@ -195,4 +237,50 @@ public Enumeration<String> getHeaderNames() {
195237
}
196238
}
197239

240+
241+
private static class ContextPathHelper {
242+
243+
private final String contextPath;
244+
245+
private final UrlPathHelper urlPathHelper;
246+
247+
248+
public ContextPathHelper(String contextPath) {
249+
Assert.notNull(contextPath);
250+
this.contextPath = sanitizeContextPath(contextPath);
251+
this.urlPathHelper = new UrlPathHelper();
252+
this.urlPathHelper.setUrlDecode(false);
253+
this.urlPathHelper.setRemoveSemicolonContent(false);
254+
}
255+
256+
private static String sanitizeContextPath(String contextPath) {
257+
contextPath = contextPath.trim();
258+
if (contextPath.isEmpty()) {
259+
return contextPath;
260+
}
261+
if (contextPath.equals("/")) {
262+
return "/";
263+
}
264+
if (contextPath.charAt(0) != '/') {
265+
contextPath = "/" + contextPath;
266+
}
267+
while (contextPath.endsWith("/")) {
268+
contextPath = contextPath.substring(0, contextPath.length() -1);
269+
}
270+
return contextPath;
271+
}
272+
273+
public String getContextPath(HttpServletRequest request) {
274+
return this.contextPath;
275+
}
276+
277+
public String getRequestUri(HttpServletRequest request) {
278+
String pathWithinApplication = this.urlPathHelper.getPathWithinApplication(request);
279+
if (this.contextPath.equals("/") && pathWithinApplication.startsWith("/")) {
280+
return pathWithinApplication;
281+
}
282+
return this.contextPath + pathWithinApplication;
283+
}
284+
}
285+
198286
}

spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java

Lines changed: 112 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
*/
1616
package org.springframework.web.filter;
1717

18+
import java.io.IOException;
1819
import javax.servlet.ServletException;
1920
import javax.servlet.http.HttpServlet;
2021
import javax.servlet.http.HttpServletRequest;
2122

23+
import org.junit.Before;
2224
import org.junit.Test;
2325

2426
import org.springframework.mock.web.test.MockFilterChain;
@@ -38,6 +40,98 @@ public class ForwardedHeaderFilterTests {
3840

3941
private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter();
4042

43+
private MockHttpServletRequest request;
44+
45+
private MockFilterChain filterChain;
46+
47+
48+
@Before
49+
public void setUp() throws Exception {
50+
this.request = new MockHttpServletRequest();
51+
this.request.setScheme("http");
52+
this.request.setServerName("localhost");
53+
this.request.setServerPort(80);
54+
this.filterChain = new MockFilterChain(new HttpServlet() {});
55+
}
56+
57+
58+
@Test(expected = IllegalArgumentException.class)
59+
public void contextPathNull() {
60+
this.filter.setContextPath(null);
61+
}
62+
63+
@Test
64+
public void contextPathEmpty() throws Exception {
65+
this.filter.setContextPath("");
66+
assertEquals("", filterAndGetContextPath());
67+
}
68+
69+
@Test
70+
public void contextPathWithExtraSpaces() throws Exception {
71+
this.filter.setContextPath(" /foo ");
72+
assertEquals("/foo", filterAndGetContextPath());
73+
}
74+
75+
@Test
76+
public void contextPathWithNoLeadingSlash() throws Exception {
77+
this.filter.setContextPath("foo");
78+
assertEquals("/foo", filterAndGetContextPath());
79+
}
80+
81+
@Test
82+
public void contextPathWithTrailingSlash() throws Exception {
83+
this.filter.setContextPath("/foo/bar/");
84+
assertEquals("/foo/bar", filterAndGetContextPath());
85+
}
86+
87+
@Test
88+
public void contextPathWithTrailingSlashes() throws Exception {
89+
this.filter.setContextPath("/foo/bar/baz///");
90+
assertEquals("/foo/bar/baz", filterAndGetContextPath());
91+
}
92+
93+
@Test
94+
public void requestUri() throws Exception {
95+
this.filter.setContextPath("/");
96+
this.request.setContextPath("/app");
97+
this.request.setRequestURI("/app/path");
98+
HttpServletRequest actual = filterAndGetWrappedRequest();
99+
100+
assertEquals("/", actual.getContextPath());
101+
assertEquals("/path", actual.getRequestURI());
102+
}
103+
104+
@Test
105+
public void requestUriWithTrailingSlash() throws Exception {
106+
this.filter.setContextPath("/");
107+
this.request.setContextPath("/app");
108+
this.request.setRequestURI("/app/path/");
109+
HttpServletRequest actual = filterAndGetWrappedRequest();
110+
111+
assertEquals("/", actual.getContextPath());
112+
assertEquals("/path/", actual.getRequestURI());
113+
}
114+
@Test
115+
public void requestUriEqualsContextPath() throws Exception {
116+
this.filter.setContextPath("/");
117+
this.request.setContextPath("/app");
118+
this.request.setRequestURI("/app");
119+
HttpServletRequest actual = filterAndGetWrappedRequest();
120+
121+
assertEquals("/", actual.getContextPath());
122+
assertEquals("/", actual.getRequestURI());
123+
}
124+
125+
@Test
126+
public void requestUriRootUrl() throws Exception {
127+
this.filter.setContextPath("/");
128+
this.request.setContextPath("/app");
129+
this.request.setRequestURI("/app/");
130+
HttpServletRequest actual = filterAndGetWrappedRequest();
131+
132+
assertEquals("/", actual.getContextPath());
133+
assertEquals("/", actual.getRequestURI());
134+
}
41135

42136
@Test
43137
public void shouldFilter() throws Exception {
@@ -54,19 +148,14 @@ public void shouldNotFilter() throws Exception {
54148

55149
@Test
56150
public void forwardedRequest() throws Exception {
57-
MockHttpServletRequest request = new MockHttpServletRequest();
58-
request.setScheme("http");
59-
request.setServerName("localhost");
60-
request.setServerPort(80);
61-
request.setRequestURI("/mvc-showcase");
62-
request.addHeader("X-Forwarded-Proto", "https");
63-
request.addHeader("X-Forwarded-Host", "84.198.58.199");
64-
request.addHeader("X-Forwarded-Port", "443");
65-
request.addHeader("foo", "bar");
66-
67-
MockFilterChain chain = new MockFilterChain(new HttpServlet() {});
68-
this.filter.doFilter(request, new MockHttpServletResponse(), chain);
69-
HttpServletRequest actual = (HttpServletRequest) chain.getRequest();
151+
this.request.setRequestURI("/mvc-showcase");
152+
this.request.addHeader("X-Forwarded-Proto", "https");
153+
this.request.addHeader("X-Forwarded-Host", "84.198.58.199");
154+
this.request.addHeader("X-Forwarded-Port", "443");
155+
this.request.addHeader("foo", "bar");
156+
157+
this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
158+
HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest();
70159

71160
assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString());
72161
assertEquals("https", actual.getScheme());
@@ -81,11 +170,20 @@ public void forwardedRequest() throws Exception {
81170
}
82171

83172

173+
private String filterAndGetContextPath() throws ServletException, IOException {
174+
return filterAndGetWrappedRequest().getContextPath();
175+
}
176+
177+
private HttpServletRequest filterAndGetWrappedRequest() throws ServletException, IOException {
178+
MockHttpServletResponse response = new MockHttpServletResponse();
179+
this.filter.doFilterInternal(this.request, response, this.filterChain);
180+
return (HttpServletRequest) this.filterChain.getRequest();
181+
}
182+
84183
private void testShouldFilter(String headerName) throws ServletException {
85184
MockHttpServletRequest request = new MockHttpServletRequest();
86185
request.addHeader(headerName, "1");
87186
assertFalse(this.filter.shouldNotFilter(request));
88187
}
89188

90-
91189
}

0 commit comments

Comments
 (0)