Skip to content

StrictHttpFirewall: Validate headers and parameters #8644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Map;

import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
Expand All @@ -31,6 +35,7 @@
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

import org.springframework.http.HttpHeaders;
import org.springframework.security.web.util.UrlUtils;

/**
Expand Down Expand Up @@ -161,6 +166,8 @@ class DummyRequest extends HttpServletRequestWrapper {
private String pathInfo;
private String queryString;
private String method;
private final HttpHeaders headers = new HttpHeaders();
private final Map<String, String[]> parameters = new LinkedHashMap<>();

DummyRequest() {
super(UNSUPPORTED_REQUEST);
Expand Down Expand Up @@ -232,6 +239,61 @@ public void setQueryString(String queryString) {
public String getServerName() {
return null;
}

@Override
public String getHeader(String name) {
return this.headers.getFirst(name);
}

@Override
public Enumeration<String> getHeaders(String name) {
return Collections.enumeration(this.headers.get(name));
}

@Override
public Enumeration<String> getHeaderNames() {
return Collections.enumeration(this.headers.keySet());
}

@Override
public int getIntHeader(String name) {
String value = this.headers.getFirst(name);
if (value == null ) {
return -1;
}
else {
return Integer.parseInt(value);
}
}

public void addHeader(String name, String value) {
this.headers.add(name, value);
}

@Override
public String getParameter(String name) {
String[] arr = this.parameters.get(name);
return (arr != null && arr.length > 0 ? arr[0] : null);
}

@Override
public Map<String, String[]> getParameterMap() {
return Collections.unmodifiableMap(this.parameters);
}

@Override
public Enumeration<String> getParameterNames() {
return Collections.enumeration(this.parameters.keySet());
}

@Override
public String[] getParameterValues(String name) {
return this.parameters.get(name);
}

public void setParameter(String name, String... values) {
this.parameters.put(name, values);
}
}

final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

Expand Down Expand Up @@ -74,6 +77,22 @@
* Rejects hosts that are not allowed. See
* {@link #setAllowedHostnames(Predicate)}
* </li>
* <li>
* Reject headers names that are not allowed. See
* {@link #setAllowedHeaderNames(Predicate)}
* </li>
* <li>
* Reject headers values that are not allowed. See
* {@link #setAllowedHeaderValues(Predicate)}
* </li>
* <li>
* Reject parameter names that are not allowed. See
* {@link #setAllowedParameterNames(Predicate)}
* </li>
* <li>
* Reject parameter values that are not allowed. See
* {@link #setAllowedParameterValues(Predicate)}
* </li>
* </ul>
*
* @see DefaultHttpFirewall
Expand Down Expand Up @@ -111,6 +130,18 @@ public class StrictHttpFirewall implements HttpFirewall {

private Predicate<String> allowedHostnames = hostname -> true;

private static final Pattern ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN = Pattern.compile("[\\p{IsAssigned}&&[^\\p{IsControl}]]*");

private static final Predicate<String> ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE = s -> ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN.matcher(s).matches();

private Predicate<String> allowedHeaderNames = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;

private Predicate<String> allowedHeaderValues = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;

private Predicate<String> allowedParameterNames = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;

private Predicate<String> allowedParameterValues = value -> true;

public StrictHttpFirewall() {
urlBlocklistsAddAll(FORBIDDEN_SEMICOLON);
urlBlocklistsAddAll(FORBIDDEN_FORWARDSLASH);
Expand Down Expand Up @@ -330,6 +361,77 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
}
}

/**
* <p>
* Determines which header names should be allowed.
* The default is to reject header names that contain ISO control characters
* and characters that are not defined.
* </p>
*
* @param allowedHeaderNames the predicate for testing header names
* @see Character#isISOControl(int)
* @see Character#isDefined(int)
* @since 5.4
*/
public void setAllowedHeaderNames(Predicate<String> allowedHeaderNames) {
if (allowedHeaderNames == null) {
throw new IllegalArgumentException("allowedHeaderNames cannot be null");
}
this.allowedHeaderNames = allowedHeaderNames;
}

/**
* <p>
* Determines which header values should be allowed.
* The default is to reject header values that contain ISO control characters
* and characters that are not defined.
* </p>
*
* @param allowedHeaderValues the predicate for testing hostnames
* @see Character#isISOControl(int)
* @see Character#isDefined(int)
* @since 5.4
*/
public void setAllowedHeaderValues(Predicate<String> allowedHeaderValues) {
if (allowedHeaderValues == null) {
throw new IllegalArgumentException("allowedHeaderValues cannot be null");
}
this.allowedHeaderValues = allowedHeaderValues;
}
/*
* Determines which parameter names should be allowed.
* The default is to reject header names that contain ISO control characters
* and characters that are not defined.
* </p>
*
* @param allowedParameterNames the predicate for testing parameter names
* @see Character#isISOControl(int)
* @see Character#isDefined(int)
* @since 5.4
*/
public void setAllowedParameterNames(Predicate<String> allowedParameterNames) {
if (allowedParameterNames == null) {
throw new IllegalArgumentException("allowedParameterNames cannot be null");
}
this.allowedParameterNames = allowedParameterNames;
}

/**
* <p>
* Determines which parameter values should be allowed.
* The default is to allow any parameter value.
* </p>
*
* @param allowedParameterValues the predicate for testing parameter values
* @since 5.4
*/
public void setAllowedParameterValues(Predicate<String> allowedParameterValues) {
if (allowedParameterValues == null) {
throw new IllegalArgumentException("allowedParameterValues cannot be null");
}
this.allowedParameterValues = allowedParameterValues;
}

/**
* <p>
* Determines which hostnames should be allowed. The default is to allow any hostname.
Expand Down Expand Up @@ -370,6 +472,144 @@ public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws
throw new RequestRejectedException("The requestURI was rejected because it can only contain printable ASCII characters.");
}
return new FirewalledRequest(request) {
@Override
public long getDateHeader(String name) {
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return super.getDateHeader(name);
}

@Override
public int getIntHeader(String name) {
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return super.getIntHeader(name);
}

@Override
public String getHeader(String name) {
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
String value = super.getHeader(name);
if (value != null && !allowedHeaderValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the header value \"" + value + "\" is not allowed.");
}
return value;
}

@Override
public Enumeration<String> getHeaders(String name) {
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}

Enumeration<String> valuesEnumeration = super.getHeaders(name);
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return valuesEnumeration.hasMoreElements();
}

@Override
public String nextElement() {
String value = valuesEnumeration.nextElement();
if (!allowedHeaderValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the header value \"" + value + "\" is not allowed.");
}
return value;
}
};
}

@Override
public Enumeration<String> getHeaderNames() {
Enumeration<String> namesEnumeration = super.getHeaderNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return namesEnumeration.hasMoreElements();
}

@Override
public String nextElement() {
String name = namesEnumeration.nextElement();
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return name;
}
};
}

@Override
public String getParameter(String name) {
if (!allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
String value = super.getParameter(name);
if (value != null && !allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
return value;
}

@Override
public Map<String, String[]> getParameterMap() {
Map<String, String[]> parameterMap = super.getParameterMap();
for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
String name = entry.getKey();
String[] values = entry.getValue();
if (!allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
for (String value: values) {
if (!allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
}
}
return parameterMap;
}

@Override
public Enumeration<String> getParameterNames() {
Enumeration<String> namesEnumeration = super.getParameterNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return namesEnumeration.hasMoreElements();
}

@Override
public String nextElement() {
String name = namesEnumeration.nextElement();
if (!allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
return name;
}
};
}

@Override
public String[] getParameterValues(String name) {
if (!allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
String[] values = super.getParameterValues(name);
if (values != null) {
for (String value: values) {
if (!allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
}
}
return values;
}

@Override
public void reset() {
}
Expand Down
Loading