Validate headers and parameters in StrictHttpFirewall

Adds methods to configure validation of header names and values and
parameter names and values:
 * setAllowedHeaderNames(Predicate)
 * setAllowedHeaderValues(Predicate)
 * setAllowedParameterNames(Predicate)
 * setAllowedParameterValues(Predicate)

By default, header names, header values, and parameter names that
contain ISO control characters or unassigned unicode characters are
rejected. No parameter value validation is performed by default.

Issue gh-8644
This commit is contained in:
Craig Andrews 2020-06-03 17:31:48 -04:00 committed by Josh Cummings
parent 88028d82ed
commit c71352c548
3 changed files with 445 additions and 0 deletions

View File

@ -23,6 +23,10 @@ import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
@ -31,6 +35,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpHeaders;
import org.springframework.security.web.util.UrlUtils;
/**
@ -161,6 +166,8 @@ class DummyRequest extends HttpServletRequestWrapper {
private String pathInfo;
private String queryString;
private String method;
private final HttpHeaders headers = new HttpHeaders();
private final Map<String, String[]> parameters = new LinkedHashMap<>();
DummyRequest() {
super(UNSUPPORTED_REQUEST);
@ -232,6 +239,61 @@ class DummyRequest extends HttpServletRequestWrapper {
public String getServerName() {
return null;
}
@Override
public String getHeader(String name) {
return this.headers.getFirst(name);
}
@Override
public Enumeration<String> getHeaders(String name) {
return Collections.enumeration(this.headers.get(name));
}
@Override
public Enumeration<String> getHeaderNames() {
return Collections.enumeration(this.headers.keySet());
}
@Override
public int getIntHeader(String name) {
String value = this.headers.getFirst(name);
if (value == null ) {
return -1;
}
else {
return Integer.parseInt(value);
}
}
public void addHeader(String name, String value) {
this.headers.add(name, value);
}
@Override
public String getParameter(String name) {
String[] arr = this.parameters.get(name);
return (arr != null && arr.length > 0 ? arr[0] : null);
}
@Override
public Map<String, String[]> getParameterMap() {
return Collections.unmodifiableMap(this.parameters);
}
@Override
public Enumeration<String> getParameterNames() {
return Collections.enumeration(this.parameters.keySet());
}
@Override
public String[] getParameterValues(String name) {
return this.parameters.get(name);
}
public void setParameter(String name, String... values) {
this.parameters.put(name, values);
}
}
final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {

View File

@ -19,10 +19,13 @@ package org.springframework.security.web.firewall;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -74,6 +77,22 @@ import org.springframework.http.HttpMethod;
* Rejects hosts that are not allowed. See
* {@link #setAllowedHostnames(Predicate)}
* </li>
* <li>
* Reject headers names that are not allowed. See
* {@link #setAllowedHeaderNames(Predicate)}
* </li>
* <li>
* Reject headers values that are not allowed. See
* {@link #setAllowedHeaderValues(Predicate)}
* </li>
* <li>
* Reject parameter names that are not allowed. See
* {@link #setAllowedParameterNames(Predicate)}
* </li>
* <li>
* Reject parameter values that are not allowed. See
* {@link #setAllowedParameterValues(Predicate)}
* </li>
* </ul>
*
* @see DefaultHttpFirewall
@ -111,6 +130,18 @@ public class StrictHttpFirewall implements HttpFirewall {
private Predicate<String> allowedHostnames = hostname -> true;
private static final Pattern ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN = Pattern.compile("[\\p{IsAssigned}&&[^\\p{IsControl}]]*");
private static final Predicate<String> ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE = s -> ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN.matcher(s).matches();
private Predicate<String> allowedHeaderNames = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
private Predicate<String> allowedHeaderValues = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
private Predicate<String> allowedParameterNames = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
private Predicate<String> allowedParameterValues = value -> true;
public StrictHttpFirewall() {
urlBlocklistsAddAll(FORBIDDEN_SEMICOLON);
urlBlocklistsAddAll(FORBIDDEN_FORWARDSLASH);
@ -330,6 +361,77 @@ public class StrictHttpFirewall implements HttpFirewall {
}
}
/**
* <p>
* Determines which header names should be allowed.
* The default is to reject header names that contain ISO control characters
* and characters that are not defined.
* </p>
*
* @param allowedHeaderNames the predicate for testing header names
* @see Character#isISOControl(int)
* @see Character#isDefined(int)
* @since 5.4
*/
public void setAllowedHeaderNames(Predicate<String> allowedHeaderNames) {
if (allowedHeaderNames == null) {
throw new IllegalArgumentException("allowedHeaderNames cannot be null");
}
this.allowedHeaderNames = allowedHeaderNames;
}
/**
* <p>
* Determines which header values should be allowed.
* The default is to reject header values that contain ISO control characters
* and characters that are not defined.
* </p>
*
* @param allowedHeaderValues the predicate for testing hostnames
* @see Character#isISOControl(int)
* @see Character#isDefined(int)
* @since 5.4
*/
public void setAllowedHeaderValues(Predicate<String> allowedHeaderValues) {
if (allowedHeaderValues == null) {
throw new IllegalArgumentException("allowedHeaderValues cannot be null");
}
this.allowedHeaderValues = allowedHeaderValues;
}
/*
* Determines which parameter names should be allowed.
* The default is to reject header names that contain ISO control characters
* and characters that are not defined.
* </p>
*
* @param allowedParameterNames the predicate for testing parameter names
* @see Character#isISOControl(int)
* @see Character#isDefined(int)
* @since 5.4
*/
public void setAllowedParameterNames(Predicate<String> allowedParameterNames) {
if (allowedParameterNames == null) {
throw new IllegalArgumentException("allowedParameterNames cannot be null");
}
this.allowedParameterNames = allowedParameterNames;
}
/**
* <p>
* Determines which parameter values should be allowed.
* The default is to allow any parameter value.
* </p>
*
* @param allowedParameterValues the predicate for testing parameter values
* @since 5.4
*/
public void setAllowedParameterValues(Predicate<String> allowedParameterValues) {
if (allowedParameterValues == null) {
throw new IllegalArgumentException("allowedParameterValues cannot be null");
}
this.allowedParameterValues = allowedParameterValues;
}
/**
* <p>
* Determines which hostnames should be allowed. The default is to allow any hostname.
@ -370,6 +472,144 @@ public class StrictHttpFirewall implements HttpFirewall {
throw new RequestRejectedException("The requestURI was rejected because it can only contain printable ASCII characters.");
}
return new FirewalledRequest(request) {
@Override
public long getDateHeader(String name) {
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return super.getDateHeader(name);
}
@Override
public int getIntHeader(String name) {
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return super.getIntHeader(name);
}
@Override
public String getHeader(String name) {
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
String value = super.getHeader(name);
if (value != null && !allowedHeaderValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the header value \"" + value + "\" is not allowed.");
}
return value;
}
@Override
public Enumeration<String> getHeaders(String name) {
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
Enumeration<String> valuesEnumeration = super.getHeaders(name);
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return valuesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String value = valuesEnumeration.nextElement();
if (!allowedHeaderValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the header value \"" + value + "\" is not allowed.");
}
return value;
}
};
}
@Override
public Enumeration<String> getHeaderNames() {
Enumeration<String> namesEnumeration = super.getHeaderNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return namesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String name = namesEnumeration.nextElement();
if (!allowedHeaderNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
}
return name;
}
};
}
@Override
public String getParameter(String name) {
if (!allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
String value = super.getParameter(name);
if (value != null && !allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
return value;
}
@Override
public Map<String, String[]> getParameterMap() {
Map<String, String[]> parameterMap = super.getParameterMap();
for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
String name = entry.getKey();
String[] values = entry.getValue();
if (!allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
for (String value: values) {
if (!allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
}
}
return parameterMap;
}
@Override
public Enumeration<String> getParameterNames() {
Enumeration<String> namesEnumeration = super.getParameterNames();
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return namesEnumeration.hasMoreElements();
}
@Override
public String nextElement() {
String name = namesEnumeration.nextElement();
if (!allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
return name;
}
};
}
@Override
public String[] getParameterValues(String name) {
if (!allowedParameterNames.test(name)) {
throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
}
String[] values = super.getParameterValues(name);
if (values != null) {
for (String value: values) {
if (!allowedParameterValues.test(value)) {
throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
}
}
}
return values;
}
@Override
public void reset() {
}

View File

@ -23,6 +23,8 @@ import static org.assertj.core.api.Assertions.fail;
import java.util.Arrays;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import org.junit.Test;
import org.springframework.http.HttpMethod;
import org.springframework.mock.web.MockHttpServletRequest;
@ -595,4 +597,145 @@ public class StrictHttpFirewallTests {
this.firewall.getFirewalledRequest(this.request);
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeaderWhenNotAllowedHeaderNameThenException() {
this.firewall.setAllowedHeaderNames(name -> !name.equals("bad name"));
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeader("bad name");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeaderWhenNotAllowedHeaderValueThenException() {
this.request.addHeader("good name", "bad value");
this.firewall.setAllowedHeaderValues(value -> !value.equals("bad value"));
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeader("good name");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetDateHeaderWhenControlCharacterInHeaderNameThenException() {
this.request.addHeader("Bad\0Name", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getDateHeader("Bad\0Name");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetIntHeaderWhenControlCharacterInHeaderNameThenException() {
this.request.addHeader("Bad\0Name", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getIntHeader("Bad\0Name");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeaderWhenControlCharacterInHeaderNameThenException() {
this.request.addHeader("Bad\0Name", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeader("Bad\0Name");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeaderWhenUndefinedCharacterInHeaderNameThenException() {
this.request.addHeader("Bad\uFFFEName", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeader("Bad\uFFFEName");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeadersWhenControlCharacterInHeaderNameThenException() {
this.request.addHeader("Bad\0Name", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeaders("Bad\0Name");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeaderNamesWhenControlCharacterInHeaderNameThenException() {
this.request.addHeader("Bad\0Name", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeaderNames().nextElement();
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeaderWhenControlCharacterInHeaderValueThenException() {
this.request.addHeader("Something", "bad\0value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeader("Something");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeaderWhenUndefinedCharacterInHeaderValueThenException() {
this.request.addHeader("Something", "bad\uFFFEvalue");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeader("Something");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetHeadersWhenControlCharacterInHeaderValueThenException() {
this.request.addHeader("Something", "bad\0value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getHeaders("Something").nextElement();
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetParameterWhenControlCharacterInParameterNameThenException() {
this.request.addParameter("Bad\0Name", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getParameter("Bad\0Name");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetParameterMapWhenControlCharacterInParameterNameThenException() {
this.request.addParameter("Bad\0Name", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getParameterMap();
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetParameterNamesWhenControlCharacterInParameterNameThenException() {
this.request.addParameter("Bad\0Name", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getParameterNames().nextElement();
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetParameterNamesWhenUndefinedCharacterInParameterNameThenException() {
this.request.addParameter("Bad\uFFFEName", "some value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getParameterNames().nextElement();
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetParameterValuesWhenNotAllowedInParameterValueThenException() {
this.firewall.setAllowedParameterValues(value -> !value.equals("bad value"));
this.request.addParameter("Something", "bad value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getParameterValues("Something");
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestGetParameterValuesWhenNotAllowedInParameterNameThenException() {
this.firewall.setAllowedParameterNames(value -> !value.equals("bad name"));
this.request.addParameter("bad name", "good value");
HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
request.getParameterValues("bad name");
}
}