HBASE-15187 Integrate CSRF prevention filter to REST gateway

This commit is contained in:
tedyu 2016-04-15 09:38:31 -07:00
parent 3e42f82600
commit 6930da781b
7 changed files with 587 additions and 35 deletions

View File

@ -1467,6 +1467,31 @@ possible configurations would overwhelm and obscure the important.
http://hbase.apache.org/devapidocs/org/apache/hadoop/hbase/master/normalizer/SimpleRegionNormalizer.html http://hbase.apache.org/devapidocs/org/apache/hadoop/hbase/master/normalizer/SimpleRegionNormalizer.html
</description> </description>
</property> </property>
<property>
<name>hbase.rest.csrf.enabled</name>
<value>false</value>
<description>
Set to true to enable protection against cross-site request forgery (CSRF)
</description>
</property>
<property>
<name>hbase.rest-csrf.browser-useragents-regex</name>
<value>^Mozilla.*,^Opera.*</value>
<description>
A comma-separated list of regular expressions used to match against an HTTP
request's User-Agent header when protection against cross-site request
forgery (CSRF) is enabled for REST server by setting
hbase.rest.csrf.enabled to true. If the incoming User-Agent matches
any of these regular expressions, then the request is considered to be sent
by a browser, and therefore CSRF prevention is enforced. If the request's
User-Agent does not match any of these regular expressions, then the request
is considered to be sent by something other than a browser, such as scripted
automation. In this case, CSRF is not a potential attack vector, so
the prevention is not enforced. This helps achieve backwards-compatibility
with existing automation that has not been updated to send the CSRF
prevention header.
</description>
</property>
<property> <property>
<name>hbase.security.exec.permission.checks</name> <name>hbase.security.exec.permission.checks</name>
<value>false</value> <value>false</value>

View File

@ -19,9 +19,11 @@
package org.apache.hadoop.hbase.rest; package org.apache.hadoop.hbase.rest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Set;
import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.HelpFormatter;
@ -35,13 +37,17 @@ import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration; import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.HBaseInterfaceAudience; import org.apache.hadoop.hbase.HBaseInterfaceAudience;
import org.apache.hadoop.hbase.http.HttpServer;
import org.apache.hadoop.hbase.http.InfoServer; import org.apache.hadoop.hbase.http.InfoServer;
import org.apache.hadoop.hbase.rest.filter.AuthFilter; import org.apache.hadoop.hbase.rest.filter.AuthFilter;
import org.apache.hadoop.hbase.rest.filter.RestCsrfPreventionFilter;
import org.apache.hadoop.hbase.security.UserProvider; import org.apache.hadoop.hbase.security.UserProvider;
import org.apache.hadoop.hbase.util.DNS; import org.apache.hadoop.hbase.util.DNS;
import org.apache.hadoop.hbase.util.HttpServerUtil; import org.apache.hadoop.hbase.util.HttpServerUtil;
import org.apache.hadoop.hbase.util.Pair;
import org.apache.hadoop.hbase.util.Strings; import org.apache.hadoop.hbase.util.Strings;
import org.apache.hadoop.hbase.util.VersionInfo; import org.apache.hadoop.hbase.util.VersionInfo;
import org.apache.hadoop.util.StringUtils;
import org.mortbay.jetty.Connector; import org.mortbay.jetty.Connector;
import org.mortbay.jetty.Server; import org.mortbay.jetty.Server;
import org.mortbay.jetty.nio.SelectChannelConnector; import org.mortbay.jetty.nio.SelectChannelConnector;
@ -66,6 +72,15 @@ import com.sun.jersey.spi.container.servlet.ServletContainer;
*/ */
@InterfaceAudience.LimitedPrivate(HBaseInterfaceAudience.TOOLS) @InterfaceAudience.LimitedPrivate(HBaseInterfaceAudience.TOOLS)
public class RESTServer implements Constants { public class RESTServer implements Constants {
static Log LOG = LogFactory.getLog("RESTServer");
static String REST_CSRF_ENABLED_KEY = "hbase.rest.csrf.enabled";
static boolean REST_CSRF_ENABLED_DEFAULT = false;
static boolean restCSRFEnabled = false;
static String REST_CSRF_CUSTOM_HEADER_KEY ="hbase.rest.csrf.custom.header";
static String REST_CSRF_CUSTOM_HEADER_DEFAULT = "X-XSRF-HEADER";
static String REST_CSRF_METHODS_TO_IGNORE_KEY = "hbase.rest.csrf.methods.to.ignore";
static String REST_CSRF_METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
private static void printUsageAndExit(Options options, int exitCode) { private static void printUsageAndExit(Options options, int exitCode) {
HelpFormatter formatter = new HelpFormatter(); HelpFormatter formatter = new HelpFormatter();
@ -76,19 +91,42 @@ public class RESTServer implements Constants {
} }
/** /**
* The main method for the HBase rest server. * Returns a list of strings from a comma-delimited configuration value.
* @param args command-line arguments *
* @throws Exception exception * @param conf configuration to check
* @param name configuration property name
* @param defaultValue default value if no value found for name
* @return list of strings from comma-delimited configuration value, or an
* empty list if not found
*/ */
public static void main(String[] args) throws Exception { private static List<String> getTrimmedStringList(Configuration conf,
Log LOG = LogFactory.getLog("RESTServer"); String name, String defaultValue) {
String valueString = conf.get(name, defaultValue);
if (valueString == null) {
return new ArrayList<>();
}
return new ArrayList<>(StringUtils.getTrimmedStringCollection(valueString));
}
static String REST_CSRF_BROWSER_USERAGENTS_REGEX_KEY = "hbase.rest-csrf.browser-useragents-regex";
static void addCSRFFilter(Context context, Configuration conf) {
restCSRFEnabled = conf.getBoolean(REST_CSRF_ENABLED_KEY, REST_CSRF_ENABLED_DEFAULT);
if (restCSRFEnabled) {
String[] urls = { "/*" };
Set<String> restCsrfMethodsToIgnore = new HashSet<>();
restCsrfMethodsToIgnore.addAll(getTrimmedStringList(conf,
REST_CSRF_METHODS_TO_IGNORE_KEY, REST_CSRF_METHODS_TO_IGNORE_DEFAULT));
Map<String, String> restCsrfParams = RestCsrfPreventionFilter
.getFilterParams(conf, "hbase.rest-csrf.");
HttpServer.defineFilter(context, "csrf", RestCsrfPreventionFilter.class.getName(),
restCsrfParams, urls);
}
}
VersionInfo.logVersion();
FilterHolder authFilter = null;
Configuration conf = HBaseConfiguration.create();
Class<? extends ServletContainer> containerClass = ServletContainer.class;
UserProvider userProvider = UserProvider.instantiate(conf);
// login the server principal (if using secure Hadoop) // login the server principal (if using secure Hadoop)
private static Pair<FilterHolder, Class<? extends ServletContainer>> loginServerPrincipal(
UserProvider userProvider, Configuration conf) throws Exception {
Class<? extends ServletContainer> containerClass = ServletContainer.class;
if (userProvider.isHadoopSecurityEnabled() && userProvider.isHBaseSecurityEnabled()) { if (userProvider.isHadoopSecurityEnabled() && userProvider.isHBaseSecurityEnabled()) {
String machineName = Strings.domainNamePointerToHostName( String machineName = Strings.domainNamePointerToHostName(
DNS.getDefaultHost(conf.get(REST_DNS_INTERFACE, "default"), DNS.getDefaultHost(conf.get(REST_DNS_INTERFACE, "default"),
@ -102,14 +140,16 @@ public class RESTServer implements Constants {
userProvider.login(REST_KEYTAB_FILE, REST_KERBEROS_PRINCIPAL, machineName); userProvider.login(REST_KEYTAB_FILE, REST_KERBEROS_PRINCIPAL, machineName);
if (conf.get(REST_AUTHENTICATION_TYPE) != null) { if (conf.get(REST_AUTHENTICATION_TYPE) != null) {
containerClass = RESTServletContainer.class; containerClass = RESTServletContainer.class;
authFilter = new FilterHolder(); FilterHolder authFilter = new FilterHolder();
authFilter.setClassName(AuthFilter.class.getName()); authFilter.setClassName(AuthFilter.class.getName());
authFilter.setName("AuthenticationFilter"); authFilter.setName("AuthenticationFilter");
return new Pair<FilterHolder, Class<? extends ServletContainer>>(authFilter,containerClass);
} }
} }
return new Pair<FilterHolder, Class<? extends ServletContainer>>(null, containerClass);
}
RESTServlet servlet = RESTServlet.getInstance(conf, userProvider); private static void parseCommandLine(String[] args, RESTServlet servlet) {
Options options = new Options(); Options options = new Options();
options.addOption("p", "port", true, "Port to bind to [default: 8080]"); options.addOption("p", "port", true, "Port to bind to [default: 8080]");
options.addOption("ro", "readonly", false, "Respond only to GET HTTP " + options.addOption("ro", "readonly", false, "Respond only to GET HTTP " +
@ -159,6 +199,24 @@ public class RESTServer implements Constants {
} else { } else {
printUsageAndExit(options, 1); printUsageAndExit(options, 1);
} }
}
/**
* The main method for the HBase rest server.
* @param args command-line arguments
* @throws Exception exception
*/
public static void main(String[] args) throws Exception {
VersionInfo.logVersion();
Configuration conf = HBaseConfiguration.create();
UserProvider userProvider = UserProvider.instantiate(conf);
Pair<FilterHolder, Class<? extends ServletContainer>> pair = loginServerPrincipal(
userProvider, conf);
FilterHolder authFilter = pair.getFirst();
Class<? extends ServletContainer> containerClass = pair.getSecond();
RESTServlet servlet = RESTServlet.getInstance(conf, userProvider);
parseCommandLine(args, servlet);
// set up the Jersey servlet container for Jetty // set up the Jersey servlet container for Jetty
ServletHolder sh = new ServletHolder(containerClass); ServletHolder sh = new ServletHolder(containerClass);
@ -234,6 +292,7 @@ public class RESTServer implements Constants {
filter = filter.trim(); filter = filter.trim();
context.addFilter(Class.forName(filter), "/*", 0); context.addFilter(Class.forName(filter), "/*", 0);
} }
addCSRFFilter(context, conf);
HttpServerUtil.constrainHttpMethods(context); HttpServerUtil.constrainHttpMethods(context);
// Put up info server. // Put up info server.
@ -245,7 +304,6 @@ public class RESTServer implements Constants {
infoServer.setAttribute("hbase.conf", conf); infoServer.setAttribute("hbase.conf", conf);
infoServer.start(); infoServer.start();
} }
// start server // start server
server.start(); server.start();
server.join(); server.join();

View File

@ -373,6 +373,20 @@ public class Client {
return put(cluster, path, contentType, content); return put(cluster, path, contentType, content);
} }
/**
* Send a PUT request
* @param path the path or URI
* @param contentType the content MIME type
* @param content the content bytes
* @param extraHdr extra Header to send
* @return a Response object with response detail
* @throws IOException
*/
public Response put(String path, String contentType, byte[] content, Header extraHdr)
throws IOException {
return put(cluster, path, contentType, content, extraHdr);
}
/** /**
* Send a PUT request * Send a PUT request
* @param cluster the cluster definition * @param cluster the cluster definition
@ -380,7 +394,7 @@ public class Client {
* @param contentType the content MIME type * @param contentType the content MIME type
* @param content the content bytes * @param content the content bytes
* @return a Response object with response detail * @return a Response object with response detail
* @throws IOException * @throws IOException for error
*/ */
public Response put(Cluster cluster, String path, String contentType, public Response put(Cluster cluster, String path, String contentType,
byte[] content) throws IOException { byte[] content) throws IOException {
@ -389,6 +403,27 @@ public class Client {
return put(cluster, path, headers, content); return put(cluster, path, headers, content);
} }
/**
* Send a PUT request
* @param cluster the cluster definition
* @param path the path or URI
* @param contentType the content MIME type
* @param content the content bytes
* @param extraHdr additional Header to send
* @return a Response object with response detail
* @throws IOException for error
*/
public Response put(Cluster cluster, String path, String contentType,
byte[] content, Header extraHdr) throws IOException {
int cnt = extraHdr == null ? 1 : 2;
Header[] headers = new Header[cnt];
headers[0] = new Header("Content-Type", contentType);
if (extraHdr != null) {
headers[1] = extraHdr;
}
return put(cluster, path, headers, content);
}
/** /**
* Send a PUT request * Send a PUT request
* @param path the path or URI * @param path the path or URI
@ -440,6 +475,20 @@ public class Client {
return post(cluster, path, contentType, content); return post(cluster, path, contentType, content);
} }
/**
* Send a POST request
* @param path the path or URI
* @param contentType the content MIME type
* @param content the content bytes
* @param extraHdr additional Header to send
* @return a Response object with response detail
* @throws IOException
*/
public Response post(String path, String contentType, byte[] content, Header extraHdr)
throws IOException {
return post(cluster, path, contentType, content, extraHdr);
}
/** /**
* Send a POST request * Send a POST request
* @param cluster the cluster definition * @param cluster the cluster definition
@ -447,7 +496,7 @@ public class Client {
* @param contentType the content MIME type * @param contentType the content MIME type
* @param content the content bytes * @param content the content bytes
* @return a Response object with response detail * @return a Response object with response detail
* @throws IOException * @throws IOException for error
*/ */
public Response post(Cluster cluster, String path, String contentType, public Response post(Cluster cluster, String path, String contentType,
byte[] content) throws IOException { byte[] content) throws IOException {
@ -456,6 +505,27 @@ public class Client {
return post(cluster, path, headers, content); return post(cluster, path, headers, content);
} }
/**
* Send a POST request
* @param cluster the cluster definition
* @param path the path or URI
* @param contentType the content MIME type
* @param content the content bytes
* @param extraHdr additional Header to send
* @return a Response object with response detail
* @throws IOException for error
*/
public Response post(Cluster cluster, String path, String contentType,
byte[] content, Header extraHdr) throws IOException {
int cnt = extraHdr == null ? 1 : 2;
Header[] headers = new Header[cnt];
headers[0] = new Header("Content-Type", contentType);
if (extraHdr != null) {
headers[1] = extraHdr;
}
return post(cluster, path, headers, content);
}
/** /**
* Send a POST request * Send a POST request
* @param path the path or URI * @param path the path or URI
@ -504,12 +574,23 @@ public class Client {
return delete(cluster, path); return delete(cluster, path);
} }
/**
* Send a DELETE request
* @param path the path or URI
* @param extraHdr additional Header to send
* @return a Response object with response detail
* @throws IOException
*/
public Response delete(String path, Header extraHdr) throws IOException {
return delete(cluster, path, extraHdr);
}
/** /**
* Send a DELETE request * Send a DELETE request
* @param cluster the cluster definition * @param cluster the cluster definition
* @param path the path or URI * @param path the path or URI
* @return a Response object with response detail * @return a Response object with response detail
* @throws IOException * @throws IOException for error
*/ */
public Response delete(Cluster cluster, String path) throws IOException { public Response delete(Cluster cluster, String path) throws IOException {
DeleteMethod method = new DeleteMethod(); DeleteMethod method = new DeleteMethod();
@ -522,4 +603,24 @@ public class Client {
method.releaseConnection(); method.releaseConnection();
} }
} }
/**
* Send a DELETE request
* @param cluster the cluster definition
* @param path the path or URI
* @return a Response object with response detail
* @throws IOException for error
*/
public Response delete(Cluster cluster, String path, Header extraHdr) throws IOException {
DeleteMethod method = new DeleteMethod();
try {
Header[] headers = { extraHdr };
int code = execute(cluster, method, headers, path);
headers = method.getResponseHeaders();
byte[] content = method.getResponseBody();
return new Response(code, headers, content);
} finally {
method.releaseConnection();
}
}
} }

View File

@ -0,0 +1,287 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.rest.filter;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This filter provides protection against cross site request forgery (CSRF)
* attacks for REST APIs. Enabling this filter on an endpoint results in the
* requirement of all client to send a particular (configurable) HTTP header
* with every request. In the absense of this header the filter will reject the
* attempt as a bad request.
*/
@InterfaceAudience.Public
@InterfaceStability.Evolving
public class RestCsrfPreventionFilter implements Filter {
private static final Logger LOG =
LoggerFactory.getLogger(RestCsrfPreventionFilter.class);
public static final String HEADER_USER_AGENT = "User-Agent";
public static final String BROWSER_USER_AGENT_PARAM =
"browser-useragents-regex";
public static final String CUSTOM_HEADER_PARAM = "custom-header";
public static final String CUSTOM_METHODS_TO_IGNORE_PARAM =
"methods-to-ignore";
static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";
public static final String HEADER_DEFAULT = "X-XSRF-HEADER";
static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
private String headerName = HEADER_DEFAULT;
private Set<String> methodsToIgnore = null;
private Set<Pattern> browserUserAgents;
@Override
public void init(FilterConfig filterConfig) throws ServletException {
String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
if (customHeader != null) {
headerName = customHeader;
}
String customMethodsToIgnore =
filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
if (customMethodsToIgnore != null) {
parseMethodsToIgnore(customMethodsToIgnore);
} else {
parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT);
}
String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
if (agents == null) {
agents = BROWSER_USER_AGENTS_DEFAULT;
}
parseBrowserUserAgents(agents);
LOG.info("Adding cross-site request forgery (CSRF) protection, "
+ "headerName = {}, methodsToIgnore = {}, browserUserAgents = {}",
headerName, methodsToIgnore, browserUserAgents);
}
void parseBrowserUserAgents(String userAgents) {
String[] agentsArray = userAgents.split(",");
browserUserAgents = new HashSet<Pattern>();
for (String patternString : agentsArray) {
browserUserAgents.add(Pattern.compile(patternString));
}
}
void parseMethodsToIgnore(String mti) {
String[] methods = mti.split(",");
methodsToIgnore = new HashSet<String>();
for (int i = 0; i < methods.length; i++) {
methodsToIgnore.add(methods[i]);
}
}
/**
* This method interrogates the User-Agent String and returns whether it
* refers to a browser. If its not a browser, then the requirement for the
* CSRF header will not be enforced; if it is a browser, the requirement will
* be enforced.
* <p>
* A User-Agent String is considered to be a browser if it matches
* any of the regex patterns from browser-useragent-regex; the default
* behavior is to consider everything a browser that matches the following:
* "^Mozilla.*,^Opera.*". Subclasses can optionally override
* this method to use different behavior.
*
* @param userAgent The User-Agent String, or null if there isn't one
* @return true if the User-Agent String refers to a browser, false if not
*/
protected boolean isBrowser(String userAgent) {
if (userAgent == null) {
return false;
}
for (Pattern pattern : browserUserAgents) {
Matcher matcher = pattern.matcher(userAgent);
if (matcher.matches()) {
return true;
}
}
return false;
}
/**
* Defines the minimal API requirements for the filter to execute its
* filtering logic. This interface exists to facilitate integration in
* components that do not run within a servlet container and therefore cannot
* rely on a servlet container to dispatch to the {@link #doFilter} method.
* Applications that do run inside a servlet container will not need to write
* code that uses this interface. Instead, they can use typical servlet
* container configuration mechanisms to insert the filter.
*/
public interface HttpInteraction {
/**
* Returns the value of a header.
*
* @param header name of header
* @return value of header
*/
String getHeader(String header);
/**
* Returns the method.
*
* @return method
*/
String getMethod();
/**
* Called by the filter after it decides that the request may proceed.
*
* @throws IOException if there is an I/O error
* @throws ServletException if the implementation relies on the servlet API
* and a servlet API call has failed
*/
void proceed() throws IOException, ServletException;
/**
* Called by the filter after it decides that the request is a potential
* CSRF attack and therefore must be rejected.
*
* @param code status code to send
* @param message response message
* @throws IOException if there is an I/O error
*/
void sendError(int code, String message) throws IOException;
}
/**
* Handles an {@link HttpInteraction} by applying the filtering logic.
*
* @param httpInteraction caller's HTTP interaction
* @throws IOException if there is an I/O error
* @throws ServletException if the implementation relies on the servlet API
* and a servlet API call has failed
*/
public void handleHttpInteraction(HttpInteraction httpInteraction)
throws IOException, ServletException {
if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) ||
methodsToIgnore.contains(httpInteraction.getMethod()) ||
httpInteraction.getHeader(headerName) != null) {
httpInteraction.proceed();
} else {
httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,
"Missing Required Header for CSRF Vulnerability Protection");
}
}
@Override
public void doFilter(ServletRequest request, ServletResponse response,
final FilterChain chain) throws IOException, ServletException {
final HttpServletRequest httpRequest = (HttpServletRequest)request;
final HttpServletResponse httpResponse = (HttpServletResponse)response;
handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest,
httpResponse, chain));
}
@Override
public void destroy() {
}
/**
* Constructs a mapping of configuration properties to be used for filter
* initialization. The mapping includes all properties that start with the
* specified configuration prefix. Property names in the mapping are trimmed
* to remove the configuration prefix.
*
* @param conf configuration to read
* @param confPrefix configuration prefix
* @return mapping of configuration properties to be used for filter
* initialization
*/
public static Map<String, String> getFilterParams(Configuration conf,
String confPrefix) {
Map<String, String> filterConfigMap = new HashMap<>();
for (Map.Entry<String, String> entry : conf) {
String name = entry.getKey();
if (name.startsWith(confPrefix)) {
String value = conf.get(name);
name = name.substring(confPrefix.length());
filterConfigMap.put(name, value);
}
}
return filterConfigMap;
}
/**
* {@link HttpInteraction} implementation for use in the servlet filter.
*/
private static final class ServletFilterHttpInteraction
implements HttpInteraction {
private final FilterChain chain;
private final HttpServletRequest httpRequest;
private final HttpServletResponse httpResponse;
/**
* Creates a new ServletFilterHttpInteraction.
*
* @param httpRequest request to process
* @param httpResponse response to process
* @param chain filter chain to forward to if HTTP interaction is allowed
*/
public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
HttpServletResponse httpResponse, FilterChain chain) {
this.httpRequest = httpRequest;
this.httpResponse = httpResponse;
this.chain = chain;
}
@Override
public String getHeader(String header) {
return httpRequest.getHeader(header);
}
@Override
public String getMethod() {
return httpRequest.getMethod();
}
@Override
public void proceed() throws IOException, ServletException {
chain.doFilter(httpRequest, httpResponse);
}
@Override
public void sendError(int code, String message) throws IOException {
httpResponse.sendError(code, message);
}
}
}

View File

@ -75,6 +75,8 @@ public class HBaseRESTTestingUtility {
filter = filter.trim(); filter = filter.trim();
context.addFilter(Class.forName(filter), "/*", 0); context.addFilter(Class.forName(filter), "/*", 0);
} }
conf.set(RESTServer.REST_CSRF_BROWSER_USERAGENTS_REGEX_KEY, ".*");
RESTServer.addCSRFFilter(context, conf);
HttpServerUtil.constrainHttpMethods(context); HttpServerUtil.constrainHttpMethods(context);
LOG.info("Loaded filter classes :" + filterClasses); LOG.info("Loaded filter classes :" + filterClasses);
// start the server // start the server

View File

@ -18,6 +18,7 @@
*/ */
package org.apache.hadoop.hbase.rest; package org.apache.hadoop.hbase.rest;
import org.apache.commons.httpclient.Header;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.*; import org.apache.hadoop.hbase.*;
import org.apache.hadoop.hbase.client.Admin; import org.apache.hadoop.hbase.client.Admin;
@ -36,6 +37,8 @@ import org.junit.AfterClass;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
import org.junit.experimental.categories.Category; import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.xml.bind.JAXBContext; import javax.xml.bind.JAXBContext;
@ -43,10 +46,14 @@ import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller; import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller; import javax.xml.bind.Unmarshaller;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@Category({RestTests.class, MediumTests.class}) @Category({RestTests.class, MediumTests.class})
@RunWith(Parameterized.class)
public class TestMultiRowResource { public class TestMultiRowResource {
private static final TableName TABLE = TableName.valueOf("TestRowResource"); private static final TableName TABLE = TableName.valueOf("TestRowResource");
@ -69,10 +76,27 @@ public class TestMultiRowResource {
private static Unmarshaller unmarshaller; private static Unmarshaller unmarshaller;
private static Configuration conf; private static Configuration conf;
private static Header extraHdr = null;
private static boolean csrfEnabled = true;
@Parameterized.Parameters
public static Collection<Object[]> data() {
List<Object[]> params = new ArrayList<Object[]>();
params.add(new Object[] {Boolean.TRUE});
params.add(new Object[] {Boolean.FALSE});
return params;
}
public TestMultiRowResource(Boolean csrf) {
csrfEnabled = csrf;
}
@BeforeClass @BeforeClass
public static void setUpBeforeClass() throws Exception { public static void setUpBeforeClass() throws Exception {
conf = TEST_UTIL.getConfiguration(); conf = TEST_UTIL.getConfiguration();
conf.setBoolean(RESTServer.REST_CSRF_ENABLED_KEY, csrfEnabled);
extraHdr = new Header(RESTServer.REST_CSRF_CUSTOM_HEADER_DEFAULT, "");
TEST_UTIL.startMiniCluster(); TEST_UTIL.startMiniCluster();
REST_TEST_UTIL.startServletContainer(conf); REST_TEST_UTIL.startServletContainer(conf);
context = JAXBContext.newInstance( context = JAXBContext.newInstance(
@ -113,16 +137,21 @@ public class TestMultiRowResource {
path.append("&row="); path.append("&row=");
path.append(ROW_2); path.append(ROW_2);
client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1)); if (csrfEnabled) {
client.post(row_6_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_2)); Response response = client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1));
assertEquals(400, response.getCode());
}
client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1), extraHdr);
client.post(row_6_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_2), extraHdr);
Response response = client.get(path.toString(), Constants.MIMETYPE_JSON); Response response = client.get(path.toString(), Constants.MIMETYPE_JSON);
assertEquals(response.getCode(), 200); assertEquals(response.getCode(), 200);
assertEquals(Constants.MIMETYPE_JSON, response.getHeader("content-type")); assertEquals(Constants.MIMETYPE_JSON, response.getHeader("content-type"));
client.delete(row_5_url); client.delete(row_5_url, extraHdr);
client.delete(row_6_url); client.delete(row_6_url, extraHdr);
} }
@ -140,16 +169,16 @@ public class TestMultiRowResource {
path.append("&row="); path.append("&row=");
path.append(ROW_2); path.append(ROW_2);
client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1)); client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1), extraHdr);
client.post(row_6_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_2)); client.post(row_6_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_2), extraHdr);
Response response = client.get(path.toString(), Constants.MIMETYPE_XML); Response response = client.get(path.toString(), Constants.MIMETYPE_XML);
assertEquals(response.getCode(), 200); assertEquals(response.getCode(), 200);
assertEquals(Constants.MIMETYPE_XML, response.getHeader("content-type")); assertEquals(Constants.MIMETYPE_XML, response.getHeader("content-type"));
client.delete(row_5_url); client.delete(row_5_url, extraHdr);
client.delete(row_6_url); client.delete(row_6_url, extraHdr);
} }
@ -165,7 +194,7 @@ public class TestMultiRowResource {
path.append("&row="); path.append("&row=");
path.append(ROW_2); path.append(ROW_2);
client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1)); client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1), extraHdr);
Response response = client.get(path.toString(), Constants.MIMETYPE_JSON); Response response = client.get(path.toString(), Constants.MIMETYPE_JSON);
assertEquals(response.getCode(), 200); assertEquals(response.getCode(), 200);
ObjectMapper mapper = new JacksonProvider().locateMapper(CellSetModel.class, ObjectMapper mapper = new JacksonProvider().locateMapper(CellSetModel.class,
@ -174,7 +203,7 @@ public class TestMultiRowResource {
assertEquals(1, cellSet.getRows().size()); assertEquals(1, cellSet.getRows().size());
assertEquals(ROW_1, Bytes.toString(cellSet.getRows().get(0).getKey())); assertEquals(ROW_1, Bytes.toString(cellSet.getRows().get(0).getKey()));
assertEquals(VALUE_1, Bytes.toString(cellSet.getRows().get(0).getCells().get(0).getValue())); assertEquals(VALUE_1, Bytes.toString(cellSet.getRows().get(0).getCells().get(0).getValue()));
client.delete(row_5_url); client.delete(row_5_url, extraHdr);
} }
} }

View File

@ -21,10 +21,15 @@ package org.apache.hadoop.hbase.rest;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.StringWriter; import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import javax.xml.bind.JAXBContext; import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException; import javax.xml.bind.JAXBException;
import org.apache.commons.httpclient.Header;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseTestingUtility; import org.apache.hadoop.hbase.HBaseTestingUtility;
import org.apache.hadoop.hbase.TableName; import org.apache.hadoop.hbase.TableName;
@ -45,8 +50,11 @@ import org.junit.AfterClass;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
import org.junit.experimental.categories.Category; import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@Category({RestTests.class, MediumTests.class}) @Category({RestTests.class, MediumTests.class})
@RunWith(Parameterized.class)
public class TestSchemaResource { public class TestSchemaResource {
private static String TABLE1 = "TestSchemaResource1"; private static String TABLE1 = "TestSchemaResource1";
private static String TABLE2 = "TestSchemaResource2"; private static String TABLE2 = "TestSchemaResource2";
@ -58,10 +66,27 @@ public class TestSchemaResource {
private static JAXBContext context; private static JAXBContext context;
private static Configuration conf; private static Configuration conf;
private static TestTableSchemaModel testTableSchemaModel; private static TestTableSchemaModel testTableSchemaModel;
private static Header extraHdr = null;
private static boolean csrfEnabled = true;
@Parameterized.Parameters
public static Collection<Object[]> data() {
List<Object[]> params = new ArrayList<Object[]>();
params.add(new Object[] {Boolean.TRUE});
params.add(new Object[] {Boolean.FALSE});
return params;
}
public TestSchemaResource(Boolean csrf) {
csrfEnabled = csrf;
}
@BeforeClass @BeforeClass
public static void setUpBeforeClass() throws Exception { public static void setUpBeforeClass() throws Exception {
conf = TEST_UTIL.getConfiguration(); conf = TEST_UTIL.getConfiguration();
conf.setBoolean(RESTServer.REST_CSRF_ENABLED_KEY, csrfEnabled);
extraHdr = new Header(RESTServer.REST_CSRF_CUSTOM_HEADER_DEFAULT, "");
TEST_UTIL.startMiniCluster(); TEST_UTIL.startMiniCluster();
REST_TEST_UTIL.startServletContainer(conf); REST_TEST_UTIL.startServletContainer(conf);
client = new Client(new Cluster().add("localhost", client = new Client(new Cluster().add("localhost",
@ -102,12 +127,18 @@ public class TestSchemaResource {
// create the table // create the table
model = testTableSchemaModel.buildTestModel(TABLE1); model = testTableSchemaModel.buildTestModel(TABLE1);
testTableSchemaModel.checkModel(model, TABLE1); testTableSchemaModel.checkModel(model, TABLE1);
if (csrfEnabled) {
// test put operation is forbidden without custom header
response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model)); response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model));
assertEquals(response.getCode(), 400);
}
response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model), extraHdr);
assertEquals(response.getCode(), 201); assertEquals(response.getCode(), 201);
// recall the same put operation but in read-only mode // recall the same put operation but in read-only mode
conf.set("hbase.rest.readonly", "true"); conf.set("hbase.rest.readonly", "true");
response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model)); response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model), extraHdr);
assertEquals(response.getCode(), 403); assertEquals(response.getCode(), 403);
// retrieve the schema and validate it // retrieve the schema and validate it
@ -124,15 +155,21 @@ public class TestSchemaResource {
model = testTableSchemaModel.fromJSON(Bytes.toString(response.getBody())); model = testTableSchemaModel.fromJSON(Bytes.toString(response.getBody()));
testTableSchemaModel.checkModel(model, TABLE1); testTableSchemaModel.checkModel(model, TABLE1);
// test delete schema operation is forbidden in read-only mode if (csrfEnabled) {
// test delete schema operation is forbidden without custom header
response = client.delete(schemaPath); response = client.delete(schemaPath);
assertEquals(400, response.getCode());
}
// test delete schema operation is forbidden in read-only mode
response = client.delete(schemaPath, extraHdr);
assertEquals(response.getCode(), 403); assertEquals(response.getCode(), 403);
// return read-only setting back to default // return read-only setting back to default
conf.set("hbase.rest.readonly", "false"); conf.set("hbase.rest.readonly", "false");
// delete the table and make sure HBase concurs // delete the table and make sure HBase concurs
response = client.delete(schemaPath); response = client.delete(schemaPath, extraHdr);
assertEquals(response.getCode(), 200); assertEquals(response.getCode(), 200);
assertFalse(admin.tableExists(TableName.valueOf(TABLE1))); assertFalse(admin.tableExists(TableName.valueOf(TABLE1)));
} }
@ -149,14 +186,21 @@ public class TestSchemaResource {
// create the table // create the table
model = testTableSchemaModel.buildTestModel(TABLE2); model = testTableSchemaModel.buildTestModel(TABLE2);
testTableSchemaModel.checkModel(model, TABLE2); testTableSchemaModel.checkModel(model, TABLE2);
if (csrfEnabled) {
// test put operation is forbidden without custom header
response = client.put(schemaPath, Constants.MIMETYPE_PROTOBUF, model.createProtobufOutput());
assertEquals(response.getCode(), 400);
}
response = client.put(schemaPath, Constants.MIMETYPE_PROTOBUF, response = client.put(schemaPath, Constants.MIMETYPE_PROTOBUF,
model.createProtobufOutput()); model.createProtobufOutput(), extraHdr);
assertEquals(response.getCode(), 201); assertEquals(response.getCode(), 201);
// recall the same put operation but in read-only mode // recall the same put operation but in read-only mode
conf.set("hbase.rest.readonly", "true"); conf.set("hbase.rest.readonly", "true");
response = client.put(schemaPath, Constants.MIMETYPE_PROTOBUF, response = client.put(schemaPath, Constants.MIMETYPE_PROTOBUF,
model.createProtobufOutput()); model.createProtobufOutput(), extraHdr);
assertNotNull(extraHdr);
assertEquals(response.getCode(), 403); assertEquals(response.getCode(), 403);
// retrieve the schema and validate it // retrieve the schema and validate it
@ -175,15 +219,21 @@ public class TestSchemaResource {
model.getObjectFromMessage(response.getBody()); model.getObjectFromMessage(response.getBody());
testTableSchemaModel.checkModel(model, TABLE2); testTableSchemaModel.checkModel(model, TABLE2);
// test delete schema operation is forbidden in read-only mode if (csrfEnabled) {
// test delete schema operation is forbidden without custom header
response = client.delete(schemaPath); response = client.delete(schemaPath);
assertEquals(400, response.getCode());
}
// test delete schema operation is forbidden in read-only mode
response = client.delete(schemaPath, extraHdr);
assertEquals(response.getCode(), 403); assertEquals(response.getCode(), 403);
// return read-only setting back to default // return read-only setting back to default
conf.set("hbase.rest.readonly", "false"); conf.set("hbase.rest.readonly", "false");
// delete the table and make sure HBase concurs // delete the table and make sure HBase concurs
response = client.delete(schemaPath); response = client.delete(schemaPath, extraHdr);
assertEquals(response.getCode(), 200); assertEquals(response.getCode(), 200);
assertFalse(admin.tableExists(TableName.valueOf(TABLE2))); assertFalse(admin.tableExists(TableName.valueOf(TABLE2)));
} }