diff --git a/hbase-common/src/main/resources/hbase-default.xml b/hbase-common/src/main/resources/hbase-default.xml index 94f1e6dc4c7..fae986edb66 100644 --- a/hbase-common/src/main/resources/hbase-default.xml +++ b/hbase-common/src/main/resources/hbase-default.xml @@ -1355,6 +1355,31 @@ possible configurations would overwhelm and obscure the important. as the SimpleLoadBalancer). + + hbase.rest.csrf.enabled + false + + Set to true to enable protection against cross-site request forgery (CSRF) + + + + hbase.rest-csrf.browser-useragents-regex + ^Mozilla.*,^Opera.* + + A comma-separated list of regular expressions used to match against an HTTP + request's User-Agent header when protection against cross-site request + forgery (CSRF) is enabled for REST server by setting + hbase.rest.csrf.enabled to true. If the incoming User-Agent matches + any of these regular expressions, then the request is considered to be sent + by a browser, and therefore CSRF prevention is enforced. If the request's + User-Agent does not match any of these regular expressions, then the request + is considered to be sent by something other than a browser, such as scripted + automation. In this case, CSRF is not a potential attack vector, so + the prevention is not enforced. This helps achieve backwards-compatibility + with existing automation that has not been updated to send the CSRF + prevention header. + + hbase.security.exec.permission.checks false diff --git a/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/RESTServer.java b/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/RESTServer.java index ad8c65dd847..9dac84a51fb 100644 --- a/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/RESTServer.java +++ b/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/RESTServer.java @@ -19,9 +19,11 @@ package org.apache.hadoop.hbase.rest; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.HelpFormatter; @@ -35,13 +37,17 @@ import org.apache.hadoop.hbase.classification.InterfaceAudience; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.HBaseConfiguration; import org.apache.hadoop.hbase.HBaseInterfaceAudience; +import org.apache.hadoop.hbase.http.HttpServer; import org.apache.hadoop.hbase.http.InfoServer; import org.apache.hadoop.hbase.rest.filter.AuthFilter; +import org.apache.hadoop.hbase.rest.filter.RestCsrfPreventionFilter; import org.apache.hadoop.hbase.security.UserProvider; import org.apache.hadoop.hbase.util.DNS; import org.apache.hadoop.hbase.util.HttpServerUtil; +import org.apache.hadoop.hbase.util.Pair; import org.apache.hadoop.hbase.util.Strings; import org.apache.hadoop.hbase.util.VersionInfo; +import org.apache.hadoop.util.StringUtils; import org.mortbay.jetty.Connector; import org.mortbay.jetty.Server; import org.mortbay.jetty.nio.SelectChannelConnector; @@ -66,6 +72,15 @@ import com.sun.jersey.spi.container.servlet.ServletContainer; */ @InterfaceAudience.LimitedPrivate(HBaseInterfaceAudience.TOOLS) public class RESTServer implements Constants { + static Log LOG = LogFactory.getLog("RESTServer"); + + static String REST_CSRF_ENABLED_KEY = "hbase.rest.csrf.enabled"; + static boolean REST_CSRF_ENABLED_DEFAULT = false; + static boolean restCSRFEnabled = false; + static String REST_CSRF_CUSTOM_HEADER_KEY ="hbase.rest.csrf.custom.header"; + static String REST_CSRF_CUSTOM_HEADER_DEFAULT = "X-XSRF-HEADER"; + static String REST_CSRF_METHODS_TO_IGNORE_KEY = "hbase.rest.csrf.methods.to.ignore"; + static String REST_CSRF_METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE"; private static void printUsageAndExit(Options options, int exitCode) { HelpFormatter formatter = new HelpFormatter(); @@ -76,19 +91,42 @@ public class RESTServer implements Constants { } /** - * The main method for the HBase rest server. - * @param args command-line arguments - * @throws Exception exception + * Returns a list of strings from a comma-delimited configuration value. + * + * @param conf configuration to check + * @param name configuration property name + * @param defaultValue default value if no value found for name + * @return list of strings from comma-delimited configuration value, or an + * empty list if not found */ - public static void main(String[] args) throws Exception { - Log LOG = LogFactory.getLog("RESTServer"); + private static List getTrimmedStringList(Configuration conf, + String name, String defaultValue) { + String valueString = conf.get(name, defaultValue); + if (valueString == null) { + return new ArrayList<>(); + } + return new ArrayList<>(StringUtils.getTrimmedStringCollection(valueString)); + } - VersionInfo.logVersion(); - FilterHolder authFilter = null; - Configuration conf = HBaseConfiguration.create(); + static String REST_CSRF_BROWSER_USERAGENTS_REGEX_KEY = "hbase.rest-csrf.browser-useragents-regex"; + static void addCSRFFilter(Context context, Configuration conf) { + restCSRFEnabled = conf.getBoolean(REST_CSRF_ENABLED_KEY, REST_CSRF_ENABLED_DEFAULT); + if (restCSRFEnabled) { + String[] urls = { "/*" }; + Set restCsrfMethodsToIgnore = new HashSet<>(); + restCsrfMethodsToIgnore.addAll(getTrimmedStringList(conf, + REST_CSRF_METHODS_TO_IGNORE_KEY, REST_CSRF_METHODS_TO_IGNORE_DEFAULT)); + Map restCsrfParams = RestCsrfPreventionFilter + .getFilterParams(conf, "hbase.rest-csrf."); + HttpServer.defineFilter(context, "csrf", RestCsrfPreventionFilter.class.getName(), + restCsrfParams, urls); + } + } + + // login the server principal (if using secure Hadoop) + private static Pair> loginServerPrincipal( + UserProvider userProvider, Configuration conf) throws Exception { Class containerClass = ServletContainer.class; - UserProvider userProvider = UserProvider.instantiate(conf); - // login the server principal (if using secure Hadoop) if (userProvider.isHadoopSecurityEnabled() && userProvider.isHBaseSecurityEnabled()) { String machineName = Strings.domainNamePointerToHostName( DNS.getDefaultHost(conf.get(REST_DNS_INTERFACE, "default"), @@ -102,14 +140,16 @@ public class RESTServer implements Constants { userProvider.login(REST_KEYTAB_FILE, REST_KERBEROS_PRINCIPAL, machineName); if (conf.get(REST_AUTHENTICATION_TYPE) != null) { containerClass = RESTServletContainer.class; - authFilter = new FilterHolder(); + FilterHolder authFilter = new FilterHolder(); authFilter.setClassName(AuthFilter.class.getName()); authFilter.setName("AuthenticationFilter"); + return new Pair>(authFilter,containerClass); } } + return new Pair>(null, containerClass); + } - RESTServlet servlet = RESTServlet.getInstance(conf, userProvider); - + private static void parseCommandLine(String[] args, RESTServlet servlet) { Options options = new Options(); options.addOption("p", "port", true, "Port to bind to [default: 8080]"); options.addOption("ro", "readonly", false, "Respond only to GET HTTP " + @@ -161,6 +201,24 @@ public class RESTServer implements Constants { } else { printUsageAndExit(options, 1); } + } + + /** + * The main method for the HBase rest server. + * @param args command-line arguments + * @throws Exception exception + */ + public static void main(String[] args) throws Exception { + VersionInfo.logVersion(); + Configuration conf = HBaseConfiguration.create(); + UserProvider userProvider = UserProvider.instantiate(conf); + Pair> pair = loginServerPrincipal( + userProvider, conf); + FilterHolder authFilter = pair.getFirst(); + Class containerClass = pair.getSecond(); + RESTServlet servlet = RESTServlet.getInstance(conf, userProvider); + + parseCommandLine(args, servlet); // set up the Jersey servlet container for Jetty ServletHolder sh = new ServletHolder(containerClass); @@ -236,6 +294,7 @@ public class RESTServer implements Constants { filter = filter.trim(); context.addFilter(Class.forName(filter), "/*", 0); } + addCSRFFilter(context, conf); HttpServerUtil.constrainHttpMethods(context); // Put up info server. @@ -247,7 +306,6 @@ public class RESTServer implements Constants { infoServer.setAttribute("hbase.conf", conf); infoServer.start(); } - // start server server.start(); server.join(); diff --git a/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/client/Client.java b/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/client/Client.java index ebedf5776a1..142c2767a71 100644 --- a/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/client/Client.java +++ b/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/client/Client.java @@ -373,6 +373,20 @@ public class Client { return put(cluster, path, contentType, content); } + /** + * Send a PUT request + * @param path the path or URI + * @param contentType the content MIME type + * @param content the content bytes + * @param extraHdr extra Header to send + * @return a Response object with response detail + * @throws IOException + */ + public Response put(String path, String contentType, byte[] content, Header extraHdr) + throws IOException { + return put(cluster, path, contentType, content, extraHdr); + } + /** * Send a PUT request * @param cluster the cluster definition @@ -380,7 +394,7 @@ public class Client { * @param contentType the content MIME type * @param content the content bytes * @return a Response object with response detail - * @throws IOException + * @throws IOException for error */ public Response put(Cluster cluster, String path, String contentType, byte[] content) throws IOException { @@ -389,6 +403,27 @@ public class Client { return put(cluster, path, headers, content); } + /** + * Send a PUT request + * @param cluster the cluster definition + * @param path the path or URI + * @param contentType the content MIME type + * @param content the content bytes + * @param extraHdr additional Header to send + * @return a Response object with response detail + * @throws IOException for error + */ + public Response put(Cluster cluster, String path, String contentType, + byte[] content, Header extraHdr) throws IOException { + int cnt = extraHdr == null ? 1 : 2; + Header[] headers = new Header[cnt]; + headers[0] = new Header("Content-Type", contentType); + if (extraHdr != null) { + headers[1] = extraHdr; + } + return put(cluster, path, headers, content); + } + /** * Send a PUT request * @param path the path or URI @@ -440,6 +475,20 @@ public class Client { return post(cluster, path, contentType, content); } + /** + * Send a POST request + * @param path the path or URI + * @param contentType the content MIME type + * @param content the content bytes + * @param extraHdr additional Header to send + * @return a Response object with response detail + * @throws IOException + */ + public Response post(String path, String contentType, byte[] content, Header extraHdr) + throws IOException { + return post(cluster, path, contentType, content, extraHdr); + } + /** * Send a POST request * @param cluster the cluster definition @@ -447,7 +496,7 @@ public class Client { * @param contentType the content MIME type * @param content the content bytes * @return a Response object with response detail - * @throws IOException + * @throws IOException for error */ public Response post(Cluster cluster, String path, String contentType, byte[] content) throws IOException { @@ -456,6 +505,27 @@ public class Client { return post(cluster, path, headers, content); } + /** + * Send a POST request + * @param cluster the cluster definition + * @param path the path or URI + * @param contentType the content MIME type + * @param content the content bytes + * @param extraHdr additional Header to send + * @return a Response object with response detail + * @throws IOException for error + */ + public Response post(Cluster cluster, String path, String contentType, + byte[] content, Header extraHdr) throws IOException { + int cnt = extraHdr == null ? 1 : 2; + Header[] headers = new Header[cnt]; + headers[0] = new Header("Content-Type", contentType); + if (extraHdr != null) { + headers[1] = extraHdr; + } + return post(cluster, path, headers, content); + } + /** * Send a POST request * @param path the path or URI @@ -504,12 +574,23 @@ public class Client { return delete(cluster, path); } + /** + * Send a DELETE request + * @param path the path or URI + * @param extraHdr additional Header to send + * @return a Response object with response detail + * @throws IOException + */ + public Response delete(String path, Header extraHdr) throws IOException { + return delete(cluster, path, extraHdr); + } + /** * Send a DELETE request * @param cluster the cluster definition * @param path the path or URI * @return a Response object with response detail - * @throws IOException + * @throws IOException for error */ public Response delete(Cluster cluster, String path) throws IOException { DeleteMethod method = new DeleteMethod(); @@ -522,4 +603,24 @@ public class Client { method.releaseConnection(); } } + + /** + * Send a DELETE request + * @param cluster the cluster definition + * @param path the path or URI + * @return a Response object with response detail + * @throws IOException for error + */ + public Response delete(Cluster cluster, String path, Header extraHdr) throws IOException { + DeleteMethod method = new DeleteMethod(); + try { + Header[] headers = { extraHdr }; + int code = execute(cluster, method, headers, path); + headers = method.getResponseHeaders(); + byte[] content = method.getResponseBody(); + return new Response(code, headers, content); + } finally { + method.releaseConnection(); + } + } } diff --git a/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/filter/RestCsrfPreventionFilter.java b/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/filter/RestCsrfPreventionFilter.java new file mode 100644 index 00000000000..30eea95cc17 --- /dev/null +++ b/hbase-rest/src/main/java/org/apache/hadoop/hbase/rest/filter/RestCsrfPreventionFilter.java @@ -0,0 +1,287 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.rest.filter; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.conf.Configuration; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This filter provides protection against cross site request forgery (CSRF) + * attacks for REST APIs. Enabling this filter on an endpoint results in the + * requirement of all client to send a particular (configurable) HTTP header + * with every request. In the absense of this header the filter will reject the + * attempt as a bad request. + */ +@InterfaceAudience.Public +@InterfaceStability.Evolving +public class RestCsrfPreventionFilter implements Filter { + + private static final Logger LOG = + LoggerFactory.getLogger(RestCsrfPreventionFilter.class); + + public static final String HEADER_USER_AGENT = "User-Agent"; + public static final String BROWSER_USER_AGENT_PARAM = + "browser-useragents-regex"; + public static final String CUSTOM_HEADER_PARAM = "custom-header"; + public static final String CUSTOM_METHODS_TO_IGNORE_PARAM = + "methods-to-ignore"; + static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*"; + public static final String HEADER_DEFAULT = "X-XSRF-HEADER"; + static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE"; + private String headerName = HEADER_DEFAULT; + private Set methodsToIgnore = null; + private Set browserUserAgents; + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM); + if (customHeader != null) { + headerName = customHeader; + } + String customMethodsToIgnore = + filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM); + if (customMethodsToIgnore != null) { + parseMethodsToIgnore(customMethodsToIgnore); + } else { + parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT); + } + + String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM); + if (agents == null) { + agents = BROWSER_USER_AGENTS_DEFAULT; + } + parseBrowserUserAgents(agents); + LOG.info("Adding cross-site request forgery (CSRF) protection, " + + "headerName = {}, methodsToIgnore = {}, browserUserAgents = {}", + headerName, methodsToIgnore, browserUserAgents); + } + + void parseBrowserUserAgents(String userAgents) { + String[] agentsArray = userAgents.split(","); + browserUserAgents = new HashSet(); + for (String patternString : agentsArray) { + browserUserAgents.add(Pattern.compile(patternString)); + } + } + + void parseMethodsToIgnore(String mti) { + String[] methods = mti.split(","); + methodsToIgnore = new HashSet(); + for (int i = 0; i < methods.length; i++) { + methodsToIgnore.add(methods[i]); + } + } + + /** + * This method interrogates the User-Agent String and returns whether it + * refers to a browser. If its not a browser, then the requirement for the + * CSRF header will not be enforced; if it is a browser, the requirement will + * be enforced. + *

+ * A User-Agent String is considered to be a browser if it matches + * any of the regex patterns from browser-useragent-regex; the default + * behavior is to consider everything a browser that matches the following: + * "^Mozilla.*,^Opera.*". Subclasses can optionally override + * this method to use different behavior. + * + * @param userAgent The User-Agent String, or null if there isn't one + * @return true if the User-Agent String refers to a browser, false if not + */ + protected boolean isBrowser(String userAgent) { + if (userAgent == null) { + return false; + } + for (Pattern pattern : browserUserAgents) { + Matcher matcher = pattern.matcher(userAgent); + if (matcher.matches()) { + return true; + } + } + return false; + } + + /** + * Defines the minimal API requirements for the filter to execute its + * filtering logic. This interface exists to facilitate integration in + * components that do not run within a servlet container and therefore cannot + * rely on a servlet container to dispatch to the {@link #doFilter} method. + * Applications that do run inside a servlet container will not need to write + * code that uses this interface. Instead, they can use typical servlet + * container configuration mechanisms to insert the filter. + */ + public interface HttpInteraction { + + /** + * Returns the value of a header. + * + * @param header name of header + * @return value of header + */ + String getHeader(String header); + + /** + * Returns the method. + * + * @return method + */ + String getMethod(); + + /** + * Called by the filter after it decides that the request may proceed. + * + * @throws IOException if there is an I/O error + * @throws ServletException if the implementation relies on the servlet API + * and a servlet API call has failed + */ + void proceed() throws IOException, ServletException; + + /** + * Called by the filter after it decides that the request is a potential + * CSRF attack and therefore must be rejected. + * + * @param code status code to send + * @param message response message + * @throws IOException if there is an I/O error + */ + void sendError(int code, String message) throws IOException; + } + + /** + * Handles an {@link HttpInteraction} by applying the filtering logic. + * + * @param httpInteraction caller's HTTP interaction + * @throws IOException if there is an I/O error + * @throws ServletException if the implementation relies on the servlet API + * and a servlet API call has failed + */ + public void handleHttpInteraction(HttpInteraction httpInteraction) + throws IOException, ServletException { + if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) || + methodsToIgnore.contains(httpInteraction.getMethod()) || + httpInteraction.getHeader(headerName) != null) { + httpInteraction.proceed(); + } else { + httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST, + "Missing Required Header for CSRF Vulnerability Protection"); + } + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + final FilterChain chain) throws IOException, ServletException { + final HttpServletRequest httpRequest = (HttpServletRequest)request; + final HttpServletResponse httpResponse = (HttpServletResponse)response; + handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest, + httpResponse, chain)); + } + + @Override + public void destroy() { + } + + /** + * Constructs a mapping of configuration properties to be used for filter + * initialization. The mapping includes all properties that start with the + * specified configuration prefix. Property names in the mapping are trimmed + * to remove the configuration prefix. + * + * @param conf configuration to read + * @param confPrefix configuration prefix + * @return mapping of configuration properties to be used for filter + * initialization + */ + public static Map getFilterParams(Configuration conf, + String confPrefix) { + Map filterConfigMap = new HashMap<>(); + for (Map.Entry entry : conf) { + String name = entry.getKey(); + if (name.startsWith(confPrefix)) { + String value = conf.get(name); + name = name.substring(confPrefix.length()); + filterConfigMap.put(name, value); + } + } + return filterConfigMap; + } + + /** + * {@link HttpInteraction} implementation for use in the servlet filter. + */ + private static final class ServletFilterHttpInteraction + implements HttpInteraction { + + private final FilterChain chain; + private final HttpServletRequest httpRequest; + private final HttpServletResponse httpResponse; + + /** + * Creates a new ServletFilterHttpInteraction. + * + * @param httpRequest request to process + * @param httpResponse response to process + * @param chain filter chain to forward to if HTTP interaction is allowed + */ + public ServletFilterHttpInteraction(HttpServletRequest httpRequest, + HttpServletResponse httpResponse, FilterChain chain) { + this.httpRequest = httpRequest; + this.httpResponse = httpResponse; + this.chain = chain; + } + + @Override + public String getHeader(String header) { + return httpRequest.getHeader(header); + } + + @Override + public String getMethod() { + return httpRequest.getMethod(); + } + + @Override + public void proceed() throws IOException, ServletException { + chain.doFilter(httpRequest, httpResponse); + } + + @Override + public void sendError(int code, String message) throws IOException { + httpResponse.sendError(code, message); + } + } +} diff --git a/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/HBaseRESTTestingUtility.java b/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/HBaseRESTTestingUtility.java index 628b17cc47c..7c3e1fd9bba 100644 --- a/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/HBaseRESTTestingUtility.java +++ b/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/HBaseRESTTestingUtility.java @@ -75,6 +75,8 @@ public class HBaseRESTTestingUtility { filter = filter.trim(); context.addFilter(Class.forName(filter), "/*", 0); } + conf.set(RESTServer.REST_CSRF_BROWSER_USERAGENTS_REGEX_KEY, ".*"); + RESTServer.addCSRFFilter(context, conf); HttpServerUtil.constrainHttpMethods(context); LOG.info("Loaded filter classes :" + filterClasses); // start the server diff --git a/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/TestMultiRowResource.java b/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/TestMultiRowResource.java index 412ccdbf8d7..76c2f6192da 100644 --- a/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/TestMultiRowResource.java +++ b/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/TestMultiRowResource.java @@ -19,6 +19,7 @@ package org.apache.hadoop.hbase.rest; +import org.apache.commons.httpclient.Header; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.*; import org.apache.hadoop.hbase.client.Admin; @@ -36,6 +37,8 @@ import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import javax.ws.rs.core.MediaType; import javax.xml.bind.JAXBContext; @@ -43,11 +46,15 @@ import javax.xml.bind.JAXBException; import javax.xml.bind.Marshaller; import javax.xml.bind.Unmarshaller; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import static org.junit.Assert.assertEquals; @Category(MediumTests.class) +@RunWith(Parameterized.class) public class TestMultiRowResource { private static final TableName TABLE = TableName.valueOf("TestRowResource"); @@ -70,10 +77,27 @@ public class TestMultiRowResource { private static Unmarshaller unmarshaller; private static Configuration conf; + private static Header extraHdr = null; + private static boolean csrfEnabled = true; + + @Parameterized.Parameters + public static Collection data() { + List params = new ArrayList(); + params.add(new Object[] {Boolean.TRUE}); + params.add(new Object[] {Boolean.FALSE}); + return params; + } + + public TestMultiRowResource(Boolean csrf) { + csrfEnabled = csrf; + } + @BeforeClass public static void setUpBeforeClass() throws Exception { conf = TEST_UTIL.getConfiguration(); + conf.setBoolean(RESTServer.REST_CSRF_ENABLED_KEY, csrfEnabled); + extraHdr = new Header(RESTServer.REST_CSRF_CUSTOM_HEADER_DEFAULT, ""); TEST_UTIL.startMiniCluster(); REST_TEST_UTIL.startServletContainer(conf); context = JAXBContext.newInstance( @@ -114,16 +138,21 @@ public class TestMultiRowResource { path.append("&row="); path.append(ROW_2); - client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1)); - client.post(row_6_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_2)); + if (csrfEnabled) { + Response response = client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1)); + assertEquals(400, response.getCode()); + } + + client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1), extraHdr); + client.post(row_6_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_2), extraHdr); Response response = client.get(path.toString(), Constants.MIMETYPE_JSON); assertEquals(response.getCode(), 200); assertEquals(Constants.MIMETYPE_JSON, response.getHeader("content-type")); - client.delete(row_5_url); - client.delete(row_6_url); + client.delete(row_5_url, extraHdr); + client.delete(row_6_url, extraHdr); } @@ -141,16 +170,16 @@ public class TestMultiRowResource { path.append("&row="); path.append(ROW_2); - client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1)); - client.post(row_6_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_2)); + client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1), extraHdr); + client.post(row_6_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_2), extraHdr); Response response = client.get(path.toString(), Constants.MIMETYPE_XML); assertEquals(response.getCode(), 200); assertEquals(Constants.MIMETYPE_XML, response.getHeader("content-type")); - client.delete(row_5_url); - client.delete(row_6_url); + client.delete(row_5_url, extraHdr); + client.delete(row_6_url, extraHdr); } @@ -166,7 +195,7 @@ public class TestMultiRowResource { path.append("&row="); path.append(ROW_2); - client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1)); + client.post(row_5_url, Constants.MIMETYPE_BINARY, Bytes.toBytes(VALUE_1), extraHdr); Response response = client.get(path.toString(), Constants.MIMETYPE_JSON); assertEquals(response.getCode(), 200); ObjectMapper mapper = new JacksonProvider().locateMapper(CellSetModel.class, @@ -175,7 +204,7 @@ public class TestMultiRowResource { assertEquals(1, cellSet.getRows().size()); assertEquals(ROW_1, Bytes.toString(cellSet.getRows().get(0).getKey())); assertEquals(VALUE_1, Bytes.toString(cellSet.getRows().get(0).getCells().get(0).getValue())); - client.delete(row_5_url); + client.delete(row_5_url, extraHdr); } } diff --git a/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/TestSchemaResource.java b/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/TestSchemaResource.java index f3891641796..7ca56f363f5 100644 --- a/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/TestSchemaResource.java +++ b/hbase-rest/src/test/java/org/apache/hadoop/hbase/rest/TestSchemaResource.java @@ -22,10 +22,15 @@ package org.apache.hadoop.hbase.rest; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.StringWriter; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import javax.xml.bind.JAXBContext; import javax.xml.bind.JAXBException; +import org.apache.commons.httpclient.Header; + import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.HBaseTestingUtility; import org.apache.hadoop.hbase.testclassification.MediumTests; @@ -45,8 +50,11 @@ import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; @Category(MediumTests.class) +@RunWith(Parameterized.class) public class TestSchemaResource { private static String TABLE1 = "TestSchemaResource1"; private static String TABLE2 = "TestSchemaResource2"; @@ -58,10 +66,27 @@ public class TestSchemaResource { private static JAXBContext context; private static Configuration conf; private static TestTableSchemaModel testTableSchemaModel; + private static Header extraHdr = null; + + private static boolean csrfEnabled = true; + + @Parameterized.Parameters + public static Collection data() { + List params = new ArrayList(); + params.add(new Object[] {Boolean.TRUE}); + params.add(new Object[] {Boolean.FALSE}); + return params; + } + + public TestSchemaResource(Boolean csrf) { + csrfEnabled = csrf; + } @BeforeClass public static void setUpBeforeClass() throws Exception { conf = TEST_UTIL.getConfiguration(); + conf.setBoolean(RESTServer.REST_CSRF_ENABLED_KEY, csrfEnabled); + extraHdr = new Header(RESTServer.REST_CSRF_CUSTOM_HEADER_DEFAULT, ""); TEST_UTIL.startMiniCluster(); REST_TEST_UTIL.startServletContainer(conf); client = new Client(new Cluster().add("localhost", @@ -102,12 +127,18 @@ public class TestSchemaResource { // create the table model = testTableSchemaModel.buildTestModel(TABLE1); testTableSchemaModel.checkModel(model, TABLE1); - response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model)); + if (csrfEnabled) { + // test put operation is forbidden without custom header + response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model)); + assertEquals(response.getCode(), 400); + } + + response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model), extraHdr); assertEquals(response.getCode(), 201); // recall the same put operation but in read-only mode conf.set("hbase.rest.readonly", "true"); - response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model)); + response = client.put(schemaPath, Constants.MIMETYPE_XML, toXML(model), extraHdr); assertEquals(response.getCode(), 403); // retrieve the schema and validate it @@ -124,15 +155,21 @@ public class TestSchemaResource { model = testTableSchemaModel.fromJSON(Bytes.toString(response.getBody())); testTableSchemaModel.checkModel(model, TABLE1); + if (csrfEnabled) { + // test delete schema operation is forbidden without custom header + response = client.delete(schemaPath); + assertEquals(400, response.getCode()); + } + // test delete schema operation is forbidden in read-only mode - response = client.delete(schemaPath); + response = client.delete(schemaPath, extraHdr); assertEquals(response.getCode(), 403); // return read-only setting back to default conf.set("hbase.rest.readonly", "false"); // delete the table and make sure HBase concurs - response = client.delete(schemaPath); + response = client.delete(schemaPath, extraHdr); assertEquals(response.getCode(), 200); assertFalse(admin.tableExists(TableName.valueOf(TABLE1))); } @@ -149,14 +186,21 @@ public class TestSchemaResource { // create the table model = testTableSchemaModel.buildTestModel(TABLE2); testTableSchemaModel.checkModel(model, TABLE2); + + if (csrfEnabled) { + // test put operation is forbidden without custom header + response = client.put(schemaPath, Constants.MIMETYPE_PROTOBUF, model.createProtobufOutput()); + assertEquals(response.getCode(), 400); + } response = client.put(schemaPath, Constants.MIMETYPE_PROTOBUF, - model.createProtobufOutput()); + model.createProtobufOutput(), extraHdr); assertEquals(response.getCode(), 201); // recall the same put operation but in read-only mode conf.set("hbase.rest.readonly", "true"); response = client.put(schemaPath, Constants.MIMETYPE_PROTOBUF, - model.createProtobufOutput()); + model.createProtobufOutput(), extraHdr); + assertNotNull(extraHdr); assertEquals(response.getCode(), 403); // retrieve the schema and validate it @@ -175,15 +219,21 @@ public class TestSchemaResource { model.getObjectFromMessage(response.getBody()); testTableSchemaModel.checkModel(model, TABLE2); + if (csrfEnabled) { + // test delete schema operation is forbidden without custom header + response = client.delete(schemaPath); + assertEquals(400, response.getCode()); + } + // test delete schema operation is forbidden in read-only mode - response = client.delete(schemaPath); + response = client.delete(schemaPath, extraHdr); assertEquals(response.getCode(), 403); // return read-only setting back to default conf.set("hbase.rest.readonly", "false"); // delete the table and make sure HBase concurs - response = client.delete(schemaPath); + response = client.delete(schemaPath, extraHdr); assertEquals(response.getCode(), 200); assertFalse(admin.tableExists(TableName.valueOf(TABLE2))); }