Add support for allowedHostnames in StrictHttpFirewall

Introduce a new method `setAllowedHostnames` which perform the validation
against untrusted hostnames.

Fixes gh-4310
This commit is contained in:
Eddú Meléndez 2019-07-26 21:56:44 -05:00 committed by Josh Cummings
parent ded83cc1b3
commit 52c80c78e5
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
3 changed files with 48 additions and 3 deletions

View File

@ -228,6 +228,11 @@ class DummyRequest extends HttpServletRequestWrapper {
public void setQueryString(String queryString) { public void setQueryString(String queryString) {
this.queryString = queryString; this.queryString = queryString;
} }
@Override
public String getServerName() {
return null;
}
} }
final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler { final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.function.Predicate;
/** /**
* <p> * <p>
@ -66,10 +67,15 @@ import java.util.Set;
* Rejects URLs that contain a URL encoded percent. See * Rejects URLs that contain a URL encoded percent. See
* {@link #setAllowUrlEncodedPercent(boolean)} * {@link #setAllowUrlEncodedPercent(boolean)}
* </li> * </li>
* <li>
* Rejects hosts that are not allowed. See
* {@link #setAllowedHostnames(Predicate)}
* </li>
* </ul> * </ul>
* *
* @see DefaultHttpFirewall * @see DefaultHttpFirewall
* @author Rob Winch * @author Rob Winch
* @author Eddú Meléndez
* @since 4.2.4 * @since 4.2.4
*/ */
public class StrictHttpFirewall implements HttpFirewall { public class StrictHttpFirewall implements HttpFirewall {
@ -96,6 +102,8 @@ public class StrictHttpFirewall implements HttpFirewall {
private Set<String> allowedHttpMethods = createDefaultAllowedHttpMethods(); private Set<String> allowedHttpMethods = createDefaultAllowedHttpMethods();
private Predicate<String> allowedHostnames = hostname -> true;
public StrictHttpFirewall() { public StrictHttpFirewall() {
urlBlacklistsAddAll(FORBIDDEN_SEMICOLON); urlBlacklistsAddAll(FORBIDDEN_SEMICOLON);
urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH); urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH);
@ -277,6 +285,13 @@ public class StrictHttpFirewall implements HttpFirewall {
} }
} }
public void setAllowedHostnames(Predicate<String> allowedHostnames) {
if (allowedHostnames == null) {
throw new IllegalArgumentException("allowedHostnames cannot be null");
}
this.allowedHostnames = allowedHostnames;
}
private void urlBlacklistsAddAll(Collection<String> values) { private void urlBlacklistsAddAll(Collection<String> values) {
this.encodedUrlBlacklist.addAll(values); this.encodedUrlBlacklist.addAll(values);
this.decodedUrlBlacklist.addAll(values); this.decodedUrlBlacklist.addAll(values);
@ -291,6 +306,7 @@ public class StrictHttpFirewall implements HttpFirewall {
public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException { public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException {
rejectForbiddenHttpMethod(request); rejectForbiddenHttpMethod(request);
rejectedBlacklistedUrls(request); rejectedBlacklistedUrls(request);
rejectedUntrustedHosts(request);
if (!isNormalized(request)) { if (!isNormalized(request)) {
throw new RequestRejectedException("The request was rejected because the URL was not normalized."); throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
@ -332,6 +348,13 @@ public class StrictHttpFirewall implements HttpFirewall {
} }
} }
private void rejectedUntrustedHosts(HttpServletRequest request) {
String serverName = request.getServerName();
if (serverName != null && !this.allowedHostnames.test(serverName)) {
throw new RequestRejectedException("The request was rejected because the domain " + serverName + " is untrusted.");
}
}
@Override @Override
public HttpServletResponse getFirewalledResponse(HttpServletResponse response) { public HttpServletResponse getFirewalledResponse(HttpServletResponse response) {
return new FirewalledResponse(response); return new FirewalledResponse(response);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ import org.springframework.mock.web.MockHttpServletRequest;
/** /**
* @author Rob Winch * @author Rob Winch
* @author Eddú Meléndez
*/ */
public class StrictHttpFirewallTests { public class StrictHttpFirewallTests {
public String[] unnormalizedPaths = { "/..", "/./path/", "/path/path/.", "/path/path//.", "./path/../path//.", public String[] unnormalizedPaths = { "/..", "/./path/", "/path/path/.", "/path/path//.", "./path/../path//.",
@ -428,4 +429,20 @@ public class StrictHttpFirewallTests {
this.firewall.getFirewalledRequest(request); this.firewall.getFirewalledRequest(request);
} }
@Test
public void getFirewalledRequestWhenTrustedDomainThenNoException() {
this.request.addHeader("Host", "example.org");
this.firewall.setAllowedHostnames(hostname -> hostname.equals("example.org"));
assertThatCode(() -> this.firewall.getFirewalledRequest(this.request)).doesNotThrowAnyException();
}
@Test(expected = RequestRejectedException.class)
public void getFirewalledRequestWhenUntrustedDomainThenException() {
this.request.addHeader("Host", "example.org");
this.firewall.setAllowedHostnames(hostname -> hostname.equals("myexample.org"));
this.firewall.getFirewalledRequest(this.request);
}
} }