diff --git a/web/src/main/java/org/springframework/security/web/FilterInvocation.java b/web/src/main/java/org/springframework/security/web/FilterInvocation.java
index 29b0407401d..97061e1872e 100644
--- a/web/src/main/java/org/springframework/security/web/FilterInvocation.java
+++ b/web/src/main/java/org/springframework/security/web/FilterInvocation.java
@@ -23,6 +23,10 @@
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.LinkedHashMap;
+import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
@@ -31,6 +35,7 @@
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
+import org.springframework.http.HttpHeaders;
import org.springframework.security.web.util.UrlUtils;
/**
@@ -161,6 +166,8 @@ class DummyRequest extends HttpServletRequestWrapper {
private String pathInfo;
private String queryString;
private String method;
+ private final HttpHeaders headers = new HttpHeaders();
+ private final Map parameters = new LinkedHashMap<>();
DummyRequest() {
super(UNSUPPORTED_REQUEST);
@@ -232,6 +239,61 @@ public void setQueryString(String queryString) {
public String getServerName() {
return null;
}
+
+ @Override
+ public String getHeader(String name) {
+ return this.headers.getFirst(name);
+ }
+
+ @Override
+ public Enumeration getHeaders(String name) {
+ return Collections.enumeration(this.headers.get(name));
+ }
+
+ @Override
+ public Enumeration getHeaderNames() {
+ return Collections.enumeration(this.headers.keySet());
+ }
+
+ @Override
+ public int getIntHeader(String name) {
+ String value = this.headers.getFirst(name);
+ if (value == null ) {
+ return -1;
+ }
+ else {
+ return Integer.parseInt(value);
+ }
+ }
+
+ public void addHeader(String name, String value) {
+ this.headers.add(name, value);
+ }
+
+ @Override
+ public String getParameter(String name) {
+ String[] arr = this.parameters.get(name);
+ return (arr != null && arr.length > 0 ? arr[0] : null);
+ }
+
+ @Override
+ public Map getParameterMap() {
+ return Collections.unmodifiableMap(this.parameters);
+ }
+
+ @Override
+ public Enumeration getParameterNames() {
+ return Collections.enumeration(this.parameters.keySet());
+ }
+
+ @Override
+ public String[] getParameterValues(String name) {
+ return this.parameters.get(name);
+ }
+
+ public void setParameter(String name, String... values) {
+ this.parameters.put(name, values);
+ }
}
final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {
diff --git a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java
index 8375846aeb7..0f12748e07f 100644
--- a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java
+++ b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java
@@ -19,10 +19,13 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
+import java.util.regex.Pattern;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@@ -74,6 +77,22 @@
* Rejects hosts that are not allowed. See
* {@link #setAllowedHostnames(Predicate)}
*
+ *
+ * Reject headers names that are not allowed. See
+ * {@link #setAllowedHeaderNames(Predicate)}
+ *
+ *
+ * Reject headers values that are not allowed. See
+ * {@link #setAllowedHeaderValues(Predicate)}
+ *
+ *
+ * Reject parameter names that are not allowed. See
+ * {@link #setAllowedParameterNames(Predicate)}
+ *
+ *
+ * Reject parameter values that are not allowed. See
+ * {@link #setAllowedParameterValues(Predicate)}
+ *
*
*
* @see DefaultHttpFirewall
@@ -111,6 +130,18 @@ public class StrictHttpFirewall implements HttpFirewall {
private Predicate allowedHostnames = hostname -> true;
+ private static final Pattern ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN = Pattern.compile("[\\p{IsAssigned}&&[^\\p{IsControl}]]*");
+
+ private static final Predicate ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE = s -> ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN.matcher(s).matches();
+
+ private Predicate allowedHeaderNames = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
+
+ private Predicate allowedHeaderValues = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
+
+ private Predicate allowedParameterNames = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
+
+ private Predicate allowedParameterValues = value -> true;
+
public StrictHttpFirewall() {
urlBlocklistsAddAll(FORBIDDEN_SEMICOLON);
urlBlocklistsAddAll(FORBIDDEN_FORWARDSLASH);
@@ -330,6 +361,77 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
}
}
+ /**
+ *
+ * Determines which header names should be allowed.
+ * The default is to reject header names that contain ISO control characters
+ * and characters that are not defined.
+ *
+ *
+ * @param allowedHeaderNames the predicate for testing header names
+ * @see Character#isISOControl(int)
+ * @see Character#isDefined(int)
+ * @since 5.4
+ */
+ public void setAllowedHeaderNames(Predicate allowedHeaderNames) {
+ if (allowedHeaderNames == null) {
+ throw new IllegalArgumentException("allowedHeaderNames cannot be null");
+ }
+ this.allowedHeaderNames = allowedHeaderNames;
+ }
+
+ /**
+ *
+ * Determines which header values should be allowed.
+ * The default is to reject header values that contain ISO control characters
+ * and characters that are not defined.
+ *
+ *
+ * @param allowedHeaderValues the predicate for testing hostnames
+ * @see Character#isISOControl(int)
+ * @see Character#isDefined(int)
+ * @since 5.4
+ */
+ public void setAllowedHeaderValues(Predicate allowedHeaderValues) {
+ if (allowedHeaderValues == null) {
+ throw new IllegalArgumentException("allowedHeaderValues cannot be null");
+ }
+ this.allowedHeaderValues = allowedHeaderValues;
+ }
+ /*
+ * Determines which parameter names should be allowed.
+ * The default is to reject header names that contain ISO control characters
+ * and characters that are not defined.
+ *
+ *
+ * @param allowedParameterNames the predicate for testing parameter names
+ * @see Character#isISOControl(int)
+ * @see Character#isDefined(int)
+ * @since 5.4
+ */
+ public void setAllowedParameterNames(Predicate allowedParameterNames) {
+ if (allowedParameterNames == null) {
+ throw new IllegalArgumentException("allowedParameterNames cannot be null");
+ }
+ this.allowedParameterNames = allowedParameterNames;
+ }
+
+ /**
+ *
+ * Determines which parameter values should be allowed.
+ * The default is to allow any parameter value.
+ *
+ *
+ * @param allowedParameterValues the predicate for testing parameter values
+ * @since 5.4
+ */
+ public void setAllowedParameterValues(Predicate allowedParameterValues) {
+ if (allowedParameterValues == null) {
+ throw new IllegalArgumentException("allowedParameterValues cannot be null");
+ }
+ this.allowedParameterValues = allowedParameterValues;
+ }
+
/**
*
* Determines which hostnames should be allowed. The default is to allow any hostname.
@@ -370,6 +472,144 @@ public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws
throw new RequestRejectedException("The requestURI was rejected because it can only contain printable ASCII characters.");
}
return new FirewalledRequest(request) {
+ @Override
+ public long getDateHeader(String name) {
+ if (!allowedHeaderNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+ }
+ return super.getDateHeader(name);
+ }
+
+ @Override
+ public int getIntHeader(String name) {
+ if (!allowedHeaderNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+ }
+ return super.getIntHeader(name);
+ }
+
+ @Override
+ public String getHeader(String name) {
+ if (!allowedHeaderNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+ }
+ String value = super.getHeader(name);
+ if (value != null && !allowedHeaderValues.test(value)) {
+ throw new RequestRejectedException("The request was rejected because the header value \"" + value + "\" is not allowed.");
+ }
+ return value;
+ }
+
+ @Override
+ public Enumeration getHeaders(String name) {
+ if (!allowedHeaderNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+ }
+
+ Enumeration valuesEnumeration = super.getHeaders(name);
+ return new Enumeration() {
+ @Override
+ public boolean hasMoreElements() {
+ return valuesEnumeration.hasMoreElements();
+ }
+
+ @Override
+ public String nextElement() {
+ String value = valuesEnumeration.nextElement();
+ if (!allowedHeaderValues.test(value)) {
+ throw new RequestRejectedException("The request was rejected because the header value \"" + value + "\" is not allowed.");
+ }
+ return value;
+ }
+ };
+ }
+
+ @Override
+ public Enumeration getHeaderNames() {
+ Enumeration namesEnumeration = super.getHeaderNames();
+ return new Enumeration() {
+ @Override
+ public boolean hasMoreElements() {
+ return namesEnumeration.hasMoreElements();
+ }
+
+ @Override
+ public String nextElement() {
+ String name = namesEnumeration.nextElement();
+ if (!allowedHeaderNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+ }
+ return name;
+ }
+ };
+ }
+
+ @Override
+ public String getParameter(String name) {
+ if (!allowedParameterNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
+ }
+ String value = super.getParameter(name);
+ if (value != null && !allowedParameterValues.test(value)) {
+ throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
+ }
+ return value;
+ }
+
+ @Override
+ public Map getParameterMap() {
+ Map parameterMap = super.getParameterMap();
+ for (Map.Entry entry : parameterMap.entrySet()) {
+ String name = entry.getKey();
+ String[] values = entry.getValue();
+ if (!allowedParameterNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
+ }
+ for (String value: values) {
+ if (!allowedParameterValues.test(value)) {
+ throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
+ }
+ }
+ }
+ return parameterMap;
+ }
+
+ @Override
+ public Enumeration getParameterNames() {
+ Enumeration namesEnumeration = super.getParameterNames();
+ return new Enumeration() {
+ @Override
+ public boolean hasMoreElements() {
+ return namesEnumeration.hasMoreElements();
+ }
+
+ @Override
+ public String nextElement() {
+ String name = namesEnumeration.nextElement();
+ if (!allowedParameterNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
+ }
+ return name;
+ }
+ };
+ }
+
+ @Override
+ public String[] getParameterValues(String name) {
+ if (!allowedParameterNames.test(name)) {
+ throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
+ }
+ String[] values = super.getParameterValues(name);
+ if (values != null) {
+ for (String value: values) {
+ if (!allowedParameterValues.test(value)) {
+ throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
+ }
+ }
+ }
+ return values;
+ }
+
@Override
public void reset() {
}
diff --git a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java
index 89971b8355b..68dbddf5891 100644
--- a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java
+++ b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java
@@ -23,6 +23,8 @@
import java.util.Arrays;
import java.util.List;
+import javax.servlet.http.HttpServletRequest;
+
import org.junit.Test;
import org.springframework.http.HttpMethod;
import org.springframework.mock.web.MockHttpServletRequest;
@@ -595,4 +597,145 @@ public void getFirewalledRequestWhenUntrustedDomainThenException() {
this.firewall.getFirewalledRequest(this.request);
}
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeaderWhenNotAllowedHeaderNameThenException() {
+ this.firewall.setAllowedHeaderNames(name -> !name.equals("bad name"));
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeader("bad name");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeaderWhenNotAllowedHeaderValueThenException() {
+ this.request.addHeader("good name", "bad value");
+ this.firewall.setAllowedHeaderValues(value -> !value.equals("bad value"));
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeader("good name");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetDateHeaderWhenControlCharacterInHeaderNameThenException() {
+ this.request.addHeader("Bad\0Name", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getDateHeader("Bad\0Name");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetIntHeaderWhenControlCharacterInHeaderNameThenException() {
+ this.request.addHeader("Bad\0Name", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getIntHeader("Bad\0Name");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeaderWhenControlCharacterInHeaderNameThenException() {
+ this.request.addHeader("Bad\0Name", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeader("Bad\0Name");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeaderWhenUndefinedCharacterInHeaderNameThenException() {
+ this.request.addHeader("Bad\uFFFEName", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeader("Bad\uFFFEName");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeadersWhenControlCharacterInHeaderNameThenException() {
+ this.request.addHeader("Bad\0Name", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeaders("Bad\0Name");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeaderNamesWhenControlCharacterInHeaderNameThenException() {
+ this.request.addHeader("Bad\0Name", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeaderNames().nextElement();
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeaderWhenControlCharacterInHeaderValueThenException() {
+ this.request.addHeader("Something", "bad\0value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeader("Something");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeaderWhenUndefinedCharacterInHeaderValueThenException() {
+ this.request.addHeader("Something", "bad\uFFFEvalue");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeader("Something");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetHeadersWhenControlCharacterInHeaderValueThenException() {
+ this.request.addHeader("Something", "bad\0value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getHeaders("Something").nextElement();
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetParameterWhenControlCharacterInParameterNameThenException() {
+ this.request.addParameter("Bad\0Name", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getParameter("Bad\0Name");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetParameterMapWhenControlCharacterInParameterNameThenException() {
+ this.request.addParameter("Bad\0Name", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getParameterMap();
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetParameterNamesWhenControlCharacterInParameterNameThenException() {
+ this.request.addParameter("Bad\0Name", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getParameterNames().nextElement();
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetParameterNamesWhenUndefinedCharacterInParameterNameThenException() {
+ this.request.addParameter("Bad\uFFFEName", "some value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getParameterNames().nextElement();
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetParameterValuesWhenNotAllowedInParameterValueThenException() {
+ this.firewall.setAllowedParameterValues(value -> !value.equals("bad value"));
+
+ this.request.addParameter("Something", "bad value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getParameterValues("Something");
+ }
+
+ @Test(expected = RequestRejectedException.class)
+ public void getFirewalledRequestGetParameterValuesWhenNotAllowedInParameterNameThenException() {
+ this.firewall.setAllowedParameterNames(value -> !value.equals("bad name"));
+
+ this.request.addParameter("bad name", "good value");
+
+ HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+ request.getParameterValues("bad name");
+ }
}