Skip to content

Commit 9c7de23

Browse files
committed
Polishing
Optimize same origin check when the request is an instance of ServletServerHttpRequest and when there is no forwarded headers. This commit also optimizes the getPort methods and ForwardedHeaderFilter forwarded headers checks. Issue: SPR-16262
1 parent c326e44 commit 9c7de23

File tree

5 files changed

+76
-40
lines changed

5 files changed

+76
-40
lines changed

spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,19 @@ public static boolean isSameOrigin(ServerHttpRequest request) {
6464
UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
6565
UriComponents actualUrl = urlBuilder.build();
6666
String actualHost = actualUrl.getHost();
67-
int actualPort = getPort(actualUrl);
67+
int actualPort = getPort(actualUrl.getScheme(), actualUrl.getPort());
6868
Assert.notNull(actualHost, "Actual request host must not be null");
6969
Assert.isTrue(actualPort != -1, "Actual request port must not be undefined");
7070
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
71-
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl));
71+
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl.getScheme(), originUrl.getPort()));
7272
}
7373

74-
private static int getPort(UriComponents uri) {
75-
int port = uri.getPort();
74+
private static int getPort(String scheme, int port) {
7675
if (port == -1) {
77-
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
76+
if ("http".equals(scheme) || "ws".equals(scheme)) {
7877
port = 80;
7978
}
80-
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
79+
else if ("https".equals(scheme) || "wss".equals(scheme)) {
8180
port = 443;
8281
}
8382
}

spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,8 @@ public void setRelativeRedirects(boolean relativeRedirects) {
118118

119119
@Override
120120
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
121-
Enumeration<String> names = request.getHeaderNames();
122-
while (names.hasMoreElements()) {
123-
String name = names.nextElement();
124-
if (FORWARDED_HEADER_NAMES.contains(name)) {
121+
for (String headerName : FORWARDED_HEADER_NAMES) {
122+
if (request.getHeader(headerName) != null) {
125123
return false;
126124
}
127125
}

spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,14 @@
1717
package org.springframework.web.filter.reactive;
1818

1919
import java.net.URI;
20-
import java.util.Collections;
21-
import java.util.Locale;
20+
import java.util.LinkedHashSet;
2221
import java.util.Set;
2322

2423
import reactor.core.publisher.Mono;
2524

2625
import org.springframework.http.HttpHeaders;
2726
import org.springframework.http.server.reactive.ServerHttpRequest;
2827
import org.springframework.lang.Nullable;
29-
import org.springframework.util.LinkedCaseInsensitiveMap;
3028
import org.springframework.web.server.ServerWebExchange;
3129
import org.springframework.web.server.WebFilter;
3230
import org.springframework.web.server.WebFilterChain;
@@ -47,8 +45,7 @@
4745
*/
4846
public class ForwardedHeaderFilter implements WebFilter {
4947

50-
private static final Set<String> FORWARDED_HEADER_NAMES =
51-
Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH));
48+
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
5249

5350
static {
5451
FORWARDED_HEADER_NAMES.add("Forwarded");
@@ -104,8 +101,13 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
104101
}
105102

106103
private boolean shouldNotFilter(ServerHttpRequest request) {
107-
return request.getHeaders().keySet().stream()
108-
.noneMatch(FORWARDED_HEADER_NAMES::contains);
104+
HttpHeaders headers = request.getHeaders();
105+
for (String headerName : FORWARDED_HEADER_NAMES) {
106+
if (headers.containsKey(headerName)) {
107+
return false;
108+
}
109+
}
110+
return true;
109111
}
110112

111113
@Nullable

spring-web/src/main/java/org/springframework/web/util/WebUtils.java

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import java.io.FileNotFoundException;
2121
import java.util.Collection;
2222
import java.util.Enumeration;
23+
import java.util.LinkedHashSet;
2324
import java.util.Map;
25+
import java.util.Set;
2426
import java.util.StringTokenizer;
2527
import java.util.TreeMap;
2628
import javax.servlet.ServletContext;
@@ -33,6 +35,7 @@
3335
import javax.servlet.http.HttpServletResponse;
3436
import javax.servlet.http.HttpSession;
3537

38+
import org.springframework.http.HttpHeaders;
3639
import org.springframework.http.HttpRequest;
3740
import org.springframework.http.server.ServletServerHttpRequest;
3841
import org.springframework.lang.Nullable;
@@ -135,6 +138,16 @@ public abstract class WebUtils {
135138
/** Key for the mutex session attribute */
136139
public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX";
137140

141+
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
142+
143+
static {
144+
FORWARDED_HEADER_NAMES.add("Forwarded");
145+
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
146+
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
147+
FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
148+
FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix");
149+
}
150+
138151

139152
/**
140153
* Set a system property to the web application root directory.
@@ -693,36 +706,60 @@ else if (CollectionUtils.isEmpty(allowedOrigins)) {
693706
* @since 4.2
694707
*/
695708
public static boolean isSameOrigin(HttpRequest request) {
696-
String origin = request.getHeaders().getOrigin();
709+
HttpHeaders headers = request.getHeaders();
710+
String origin = headers.getOrigin();
697711
if (origin == null) {
698712
return true;
699713
}
700-
UriComponentsBuilder urlBuilder;
714+
String scheme;
715+
String host;
716+
int port;
701717
if (request instanceof ServletServerHttpRequest) {
702718
// Build more efficiently if we can: we only need scheme, host, port for origin comparison
703719
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
704-
urlBuilder = new UriComponentsBuilder().
705-
scheme(servletRequest.getScheme()).
706-
host(servletRequest.getServerName()).
707-
port(servletRequest.getServerPort()).
708-
adaptFromForwardedHeaders(request.getHeaders());
720+
scheme = servletRequest.getScheme();
721+
host = servletRequest.getServerName();
722+
port = servletRequest.getServerPort();
723+
724+
if(containsForwardedHeaders(servletRequest)) {
725+
UriComponents actualUrl = new UriComponentsBuilder()
726+
.scheme(scheme)
727+
.host(host)
728+
.port(port)
729+
.adaptFromForwardedHeaders(headers)
730+
.build();
731+
scheme = actualUrl.getScheme();
732+
host = actualUrl.getHost();
733+
port = actualUrl.getPort();
734+
}
709735
}
710736
else {
711-
urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
737+
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
738+
scheme = actualUrl.getScheme();
739+
host = actualUrl.getHost();
740+
port = actualUrl.getPort();
712741
}
713-
UriComponents actualUrl = urlBuilder.build();
742+
714743
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
715-
return (ObjectUtils.nullSafeEquals(actualUrl.getHost(), originUrl.getHost()) &&
716-
getPort(actualUrl) == getPort(originUrl));
744+
return (ObjectUtils.nullSafeEquals(host, originUrl.getHost()) &&
745+
getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort()));
746+
}
747+
748+
private static boolean containsForwardedHeaders(HttpServletRequest request) {
749+
for (String headerName : FORWARDED_HEADER_NAMES) {
750+
if (request.getHeader(headerName) != null) {
751+
return true;
752+
}
753+
}
754+
return false;
717755
}
718756

719-
private static int getPort(UriComponents uri) {
720-
int port = uri.getPort();
757+
private static int getPort(String scheme, int port) {
721758
if (port == -1) {
722-
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
759+
if ("http".equals(scheme) || "ws".equals(scheme)) {
723760
port = 80;
724761
}
725-
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
762+
else if ("https".equals(scheme) || "wss".equals(scheme)) {
726763
port = 443;
727764
}
728765
}

spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ private boolean checkValidOrigin(String serverName, int port, String originHeade
168168
if (port != -1) {
169169
servletRequest.setServerPort(port);
170170
}
171-
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
171+
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
172172
return WebUtils.isValidOrigin(request, allowed);
173173
}
174174

@@ -179,7 +179,7 @@ private boolean checkSameOrigin(String serverName, int port, String originHeader
179179
if (port != -1) {
180180
servletRequest.setServerPort(port);
181181
}
182-
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
182+
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
183183
return WebUtils.isSameOrigin(request);
184184
}
185185

@@ -191,15 +191,15 @@ private boolean checkSameOriginWithXForwardedHeaders(String serverName, int port
191191
servletRequest.setServerPort(port);
192192
}
193193
if (forwardedProto != null) {
194-
request.getHeaders().set("X-Forwarded-Proto", forwardedProto);
194+
servletRequest.addHeader("X-Forwarded-Proto", forwardedProto);
195195
}
196196
if (forwardedHost != null) {
197-
request.getHeaders().set("X-Forwarded-Host", forwardedHost);
197+
servletRequest.addHeader("X-Forwarded-Host", forwardedHost);
198198
}
199199
if (forwardedPort != -1) {
200-
request.getHeaders().set("X-Forwarded-Port", String.valueOf(forwardedPort));
200+
servletRequest.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort));
201201
}
202-
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
202+
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
203203
return WebUtils.isSameOrigin(request);
204204
}
205205

@@ -210,8 +210,8 @@ private boolean checkSameOriginWithForwardedHeader(String serverName, int port,
210210
if (port != -1) {
211211
servletRequest.setServerPort(port);
212212
}
213-
request.getHeaders().set("Forwarded", forwardedHeader);
214-
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
213+
servletRequest.addHeader("Forwarded", forwardedHeader);
214+
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
215215
return WebUtils.isSameOrigin(request);
216216
}
217217

0 commit comments

Comments
 (0)