diff --git a/hadoop-common-project/hadoop-common/CHANGES.txt b/hadoop-common-project/hadoop-common/CHANGES.txt index a117d50db3e..4ba91a2b486 100644 --- a/hadoop-common-project/hadoop-common/CHANGES.txt +++ b/hadoop-common-project/hadoop-common/CHANGES.txt @@ -731,6 +731,9 @@ Release 2.8.0 - UNRELEASED HADOOP-11262. Enable YARN to use S3A. (Pieter Reuse via lei) + HADOOP-12691. Add CSRF Filter for REST APIs to Hadoop Common. + (Larry McCay via cnauroth) + IMPROVEMENTS HADOOP-12458. Retries is typoed to spell Retires in parts of diff --git a/hadoop-common/src/main/java/org/apache/hadoop/security/http/RestCsrfPreventionFilter.java b/hadoop-common/src/main/java/org/apache/hadoop/security/http/RestCsrfPreventionFilter.java new file mode 100644 index 00000000000..50f95adea70 --- /dev/null +++ b/hadoop-common/src/main/java/org/apache/hadoop/security/http/RestCsrfPreventionFilter.java @@ -0,0 +1,89 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.security.http; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * This filter provides protection against cross site request forgery (CSRF) + * attacks for REST APIs. Enabling this filter on an endpoint results in the + * requirement of all client to send a particular (configurable) HTTP header + * with every request. In the absense of this header the filter will reject the + * attempt as a bad request. + */ +public class RestCsrfPreventionFilter implements Filter { + public static final String CUSTOM_HEADER_PARAM = "custom-header"; + public static final String CUSTOM_METHODS_TO_IGNORE_PARAM = + "methods-to-ignore"; + static final String HEADER_DEFAULT = "X-XSRF-HEADER"; + static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE"; + private String headerName = HEADER_DEFAULT; + private Set methodsToIgnore = null; + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM); + if (customHeader != null) { + headerName = customHeader; + } + String customMethodsToIgnore = + filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM); + if (customMethodsToIgnore != null) { + parseMethodsToIgnore(customMethodsToIgnore); + } else { + parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT); + } + } + + void parseMethodsToIgnore(String mti) { + String[] methods = mti.split(","); + methodsToIgnore = new HashSet(); + for (int i = 0; i < methods.length; i++) { + methodsToIgnore.add(methods[i]); + } + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + HttpServletRequest httpRequest = (HttpServletRequest)request; + if (methodsToIgnore.contains(httpRequest.getMethod()) || + httpRequest.getHeader(headerName) != null) { + chain.doFilter(request, response); + } else { + ((HttpServletResponse)response).sendError( + HttpServletResponse.SC_BAD_REQUEST, + "Missing Required Header for Vulnerability Protection"); + } + } + + @Override + public void destroy() { + } +} diff --git a/hadoop-common/src/test/java/org/apache/hadoop/security/http/TestRestCsrfPreventionFilter.java b/hadoop-common/src/test/java/org/apache/hadoop/security/http/TestRestCsrfPreventionFilter.java new file mode 100644 index 00000000000..adf89f5b44d --- /dev/null +++ b/hadoop-common/src/test/java/org/apache/hadoop/security/http/TestRestCsrfPreventionFilter.java @@ -0,0 +1,276 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.security.http; + +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Test; +import org.mockito.Mockito; + +public class TestRestCsrfPreventionFilter { + + private static final String EXPECTED_MESSAGE = + "Missing Required Header for Vulnerability Protection"; + private static final String X_CUSTOM_HEADER = "X-CUSTOM_HEADER"; + + @Test + public void testNoHeaderDefaultConfig_badRequest() + throws ServletException, IOException { + // Setup the configuration settings of the server + FilterConfig filterConfig = Mockito.mock(FilterConfig.class); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_HEADER_PARAM)).thenReturn(null); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_METHODS_TO_IGNORE_PARAM)). + thenReturn(null); + + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(RestCsrfPreventionFilter.HEADER_DEFAULT)). + thenReturn(null); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + RestCsrfPreventionFilter filter = new RestCsrfPreventionFilter(); + filter.init(filterConfig); + filter.doFilter(mockReq, mockRes, mockChain); + + verify(mockRes, atLeastOnce()).sendError( + HttpServletResponse.SC_BAD_REQUEST, EXPECTED_MESSAGE); + Mockito.verifyZeroInteractions(mockChain); + } + + @Test + public void testHeaderPresentDefaultConfig_goodRequest() + throws ServletException, IOException { + // Setup the configuration settings of the server + FilterConfig filterConfig = Mockito.mock(FilterConfig.class); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_HEADER_PARAM)).thenReturn(null); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_METHODS_TO_IGNORE_PARAM)). + thenReturn(null); + + // CSRF HAS been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(RestCsrfPreventionFilter.HEADER_DEFAULT)). + thenReturn("valueUnimportant"); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + RestCsrfPreventionFilter filter = new RestCsrfPreventionFilter(); + filter.init(filterConfig); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verify(mockChain).doFilter(mockReq, mockRes); + } + + @Test + public void testHeaderPresentCustomHeaderConfig_goodRequest() + throws ServletException, IOException { + // Setup the configuration settings of the server + FilterConfig filterConfig = Mockito.mock(FilterConfig.class); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_HEADER_PARAM)). + thenReturn(X_CUSTOM_HEADER); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_METHODS_TO_IGNORE_PARAM)). + thenReturn(null); + + // CSRF HAS been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(X_CUSTOM_HEADER)). + thenReturn("valueUnimportant"); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + RestCsrfPreventionFilter filter = new RestCsrfPreventionFilter(); + filter.init(filterConfig); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verify(mockChain).doFilter(mockReq, mockRes); + } + + @Test + public void testMissingHeaderWithCustomHeaderConfig_badRequest() + throws ServletException, IOException { + // Setup the configuration settings of the server + FilterConfig filterConfig = Mockito.mock(FilterConfig.class); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_HEADER_PARAM)). + thenReturn(X_CUSTOM_HEADER); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_METHODS_TO_IGNORE_PARAM)). + thenReturn(null); + + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(RestCsrfPreventionFilter.HEADER_DEFAULT)). + thenReturn(null); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + RestCsrfPreventionFilter filter = new RestCsrfPreventionFilter(); + filter.init(filterConfig); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verifyZeroInteractions(mockChain); + } + + @Test + public void testMissingHeaderNoMethodsToIgnoreConfig_badRequest() + throws ServletException, IOException { + // Setup the configuration settings of the server + FilterConfig filterConfig = Mockito.mock(FilterConfig.class); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_HEADER_PARAM)).thenReturn(null); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_METHODS_TO_IGNORE_PARAM)). + thenReturn(""); + + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(RestCsrfPreventionFilter.HEADER_DEFAULT)). + thenReturn(null); + Mockito.when(mockReq.getMethod()). + thenReturn("GET"); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + RestCsrfPreventionFilter filter = new RestCsrfPreventionFilter(); + filter.init(filterConfig); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verifyZeroInteractions(mockChain); + } + + @Test + public void testMissingHeaderIgnoreGETMethodConfig_goodRequest() + throws ServletException, IOException { + // Setup the configuration settings of the server + FilterConfig filterConfig = Mockito.mock(FilterConfig.class); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_HEADER_PARAM)).thenReturn(null); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_METHODS_TO_IGNORE_PARAM)). + thenReturn("GET"); + + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(RestCsrfPreventionFilter.HEADER_DEFAULT)). + thenReturn(null); + Mockito.when(mockReq.getMethod()). + thenReturn("GET"); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + RestCsrfPreventionFilter filter = new RestCsrfPreventionFilter(); + filter.init(filterConfig); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verify(mockChain).doFilter(mockReq, mockRes); + } + + @Test + public void testMissingHeaderMultipleIgnoreMethodsConfig_goodRequest() + throws ServletException, IOException { + // Setup the configuration settings of the server + FilterConfig filterConfig = Mockito.mock(FilterConfig.class); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_HEADER_PARAM)).thenReturn(null); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_METHODS_TO_IGNORE_PARAM)). + thenReturn("GET,OPTIONS"); + + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(RestCsrfPreventionFilter.HEADER_DEFAULT)). + thenReturn(null); + Mockito.when(mockReq.getMethod()). + thenReturn("OPTIONS"); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + RestCsrfPreventionFilter filter = new RestCsrfPreventionFilter(); + filter.init(filterConfig); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verify(mockChain).doFilter(mockReq, mockRes); + } + + @Test + public void testMissingHeaderMultipleIgnoreMethodsConfig_badRequest() + throws ServletException, IOException { + // Setup the configuration settings of the server + FilterConfig filterConfig = Mockito.mock(FilterConfig.class); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_HEADER_PARAM)).thenReturn(null); + Mockito.when(filterConfig.getInitParameter( + RestCsrfPreventionFilter.CUSTOM_METHODS_TO_IGNORE_PARAM)). + thenReturn("GET,OPTIONS"); + + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(RestCsrfPreventionFilter.HEADER_DEFAULT)). + thenReturn(null); + Mockito.when(mockReq.getMethod()). + thenReturn("PUT"); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + RestCsrfPreventionFilter filter = new RestCsrfPreventionFilter(); + filter.init(filterConfig); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verifyZeroInteractions(mockChain); + } +}