NIFI-11767 Refactored Groovy tests in nifi-web-error and nifi-web-security to Java

This closes #7457

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
dan-s1 2023-07-03 17:46:35 +00:00 committed by exceptionfactory
parent b3372900b3
commit d24318cdb8
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
7 changed files with 580 additions and 811 deletions

View File

@ -1,135 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.filter
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.servlet.FilterChain
import javax.servlet.FilterConfig
import javax.servlet.RequestDispatcher
import javax.servlet.ServletContext
import javax.servlet.ServletRequest
import javax.servlet.ServletResponse
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
import static org.junit.jupiter.api.Assertions.assertEquals
class CatchAllFilterTest {
private static final Logger logger = LoggerFactory.getLogger(CatchAllFilterTest.class)
@BeforeAll
static void setUpOnce() throws Exception {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
private static String getValue(String parameterName, Map<String, String> params = [:]) {
params.containsKey(parameterName) ? params[parameterName] : ""
}
@Test
void testInitShouldCallSuper() {
// Arrange
def EXPECTED_ALLOWED_CONTEXT_PATHS = ["/path1", "/path2"].join(", ")
def parameters = [allowedContextPaths: EXPECTED_ALLOWED_CONTEXT_PATHS]
FilterConfig mockFilterConfig = [
getInitParameter : { String parameterName ->
return getValue(parameterName, parameters)
},
getServletContext: { ->
[getInitParameter: { String parameterName ->
return getValue(parameterName, parameters)
}] as ServletContext
}] as FilterConfig
CatchAllFilter caf = new CatchAllFilter()
// Act
caf.init(mockFilterConfig)
logger.info("Allowed context paths: ${caf.getAllowedContextPaths()}")
// Assert
assertEquals(EXPECTED_ALLOWED_CONTEXT_PATHS, caf.getAllowedContextPaths())
}
@Test
void testShouldDoFilter() {
// Arrange
final String EXPECTED_ALLOWED_CONTEXT_PATHS = ["/path1", "/path2"].join(", ")
final String EXPECTED_FORWARD_PATH = "index.jsp"
final Map PARAMETERS = [
allowedContextPaths: EXPECTED_ALLOWED_CONTEXT_PATHS,
forwardPath : EXPECTED_FORWARD_PATH
]
final String EXPECTED_CONTEXT_PATH = ""
// Mock collaborators
FilterConfig mockFilterConfig = [
getInitParameter : { String parameterName ->
return getValue(parameterName, PARAMETERS)
},
getServletContext: { ->
[getInitParameter: { String parameterName ->
return getValue(parameterName, PARAMETERS)
}] as ServletContext
}] as FilterConfig
// Local map to store request attributes
def requestAttributes = [:]
// Local string to store resulting path
String forwardedRequestTo = ""
final Map HEADERS = [
"X-ProxyContextPath" : "",
"X-Forwarded-Context": "",
"X-Forwarded-Prefix" : ""]
HttpServletRequest mockRequest = [
getContextPath : { -> EXPECTED_CONTEXT_PATH },
getHeader : { String headerName -> getValue(headerName, HEADERS) },
setAttribute : { String attr, String value ->
requestAttributes[attr] = value
logger.mock("Set request attribute ${attr} to ${value}")
},
getRequestDispatcher: { String path ->
[forward: { ServletRequest request, ServletResponse response ->
forwardedRequestTo = path
logger.mock("Forwarded request to ${path}")
}] as RequestDispatcher
}] as HttpServletRequest
HttpServletResponse mockResponse = [:] as HttpServletResponse
FilterChain mockFilterChain = [:] as FilterChain
CatchAllFilter caf = new CatchAllFilter()
caf.init(mockFilterConfig)
logger.info("Allowed context paths: ${caf.getAllowedContextPaths()}")
// Act
caf.doFilter(mockRequest, mockResponse, mockFilterChain)
// Assert
assertEquals(EXPECTED_FORWARD_PATH, forwardedRequestTo)
}
}

View File

@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.filter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class CatchAllFilterTest {
@Mock
private ServletContext servletContext;
@Mock
private FilterConfig filterConfig;
@BeforeEach
public void setUp() {
when(filterConfig.getServletContext()).thenReturn(servletContext);
}
@Test
public void testInitShouldCallSuper() throws ServletException {
String expectedAllowedContextPaths = getExpectedAllowedContextPaths();
final Map<String, String> parameters = Collections.singletonMap("allowedContextPaths", getExpectedAllowedContextPaths());
when(servletContext.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters));
when(filterConfig.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters));
CatchAllFilter catchAllFilter = new CatchAllFilter();
catchAllFilter.init(filterConfig);
assertEquals(expectedAllowedContextPaths, catchAllFilter.getAllowedContextPaths());
}
@Test
public void testShouldDoFilter(@Mock HttpServletRequest request, @Mock RequestDispatcher requestDispatcher,
@Mock HttpServletResponse response, @Mock FilterChain filterChain ) throws ServletException, IOException {
final String expectedAllowedContextPaths = getExpectedAllowedContextPaths();
final String expectedForwardPath = "index.jsp";
final Map<String, String> parameters = new HashMap<>();
parameters.put("allowedContextPaths", expectedAllowedContextPaths);
parameters.put("forwardPath", expectedForwardPath);
final Map<String, Object> requestAttributes = new HashMap<>();
final String[] forwardedRequestTo = new String[1];
final Map<String, String> headers = new HashMap<>();
headers.put("X-ProxyContextPath", "");
headers.put("X-Forwarded-Context", "");
headers.put("X-Forwarded-Prefix", "");
when(servletContext.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters));
when(filterConfig.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters));
when(request.getHeader(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), headers));
doAnswer(invocation -> requestAttributes.put(invocation.getArgument(0), invocation.getArgument(1))).when(request).setAttribute(anyString(), any());
when(request.getRequestDispatcher(anyString())).thenAnswer(outerInvocation -> {
doAnswer(innerInvocation -> forwardedRequestTo[0] = outerInvocation.getArgument(0)).when(requestDispatcher).forward(any(), any());
return requestDispatcher;});
CatchAllFilter catchAllFilter = new CatchAllFilter();
catchAllFilter.init(filterConfig);
catchAllFilter.doFilter(request, response, filterChain);
assertEquals(expectedForwardPath, forwardedRequestTo[0]);
}
private String getExpectedAllowedContextPaths() {
return String.join(",", "/path1", "/path2");
}
private static String getValue(String parameterName, Map<String, String> params) {
return params.getOrDefault(parameterName, "");
}
}

View File

@ -251,11 +251,5 @@
<artifactId>jetty-servlet</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.codehaus.groovy</groupId>
<artifactId>groovy-json</artifactId>
<version>${nifi.groovy.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -1,393 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security
import org.apache.nifi.authorization.user.NiFiUser
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
class ProxiedEntitiesUtilsTest {
private static final Logger logger = LoggerFactory.getLogger(ProxiedEntitiesUtils.class)
private static final String SAFE_USER_NAME_JOHN = "jdoe"
private static final String SAFE_USER_DN_JOHN = "CN=${SAFE_USER_NAME_JOHN}, OU=Apache NiFi"
private static final String SAFE_USER_NAME_PROXY_1 = "proxy1.nifi.apache.org"
private static final String SAFE_USER_DN_PROXY_1 = "CN=${SAFE_USER_NAME_PROXY_1}, OU=Apache NiFi"
private static final String SAFE_USER_NAME_PROXY_2 = "proxy2.nifi.apache.org"
private static final String SAFE_USER_DN_PROXY_2 = "CN=${SAFE_USER_NAME_PROXY_2}, OU=Apache NiFi"
private static
final String MALICIOUS_USER_NAME_JOHN = "${SAFE_USER_NAME_JOHN}, OU=Apache NiFi><CN=${SAFE_USER_NAME_PROXY_1}"
private static final String MALICIOUS_USER_DN_JOHN = "CN=${MALICIOUS_USER_NAME_JOHN}, OU=Apache NiFi"
private static
final String MALICIOUS_USER_NAME_JOHN_ESCAPED = sanitizeDn(MALICIOUS_USER_NAME_JOHN)
private static final String UNICODE_DN_1 = "CN=Алйс, OU=Apache NiFi"
private static final String UNICODE_DN_1_ENCODED = "<" + base64Encode(UNICODE_DN_1) + ">"
private static final String UNICODE_DN_2 = "CN=Боб, OU=Apache NiFi"
private static final String UNICODE_DN_2_ENCODED = "<" + base64Encode(UNICODE_DN_2) + ">"
@BeforeAll
static void setUpOnce() throws Exception {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
private static String sanitizeDn(String dn = "") {
dn.replaceAll(/>/, '\\\\>').replaceAll('<', '\\\\<')
}
private static String base64Encode(String dn = "") {
return Base64.getEncoder().encodeToString(dn.getBytes(StandardCharsets.UTF_8))
}
private static String printUnicodeString(final String raw) {
StringBuilder sb = new StringBuilder()
for (int i = 0; i < raw.size(); i++) {
int codePoint = Character.codePointAt(raw, i)
int charCount = Character.charCount(codePoint)
if (charCount > 1) {
i += charCount - 1 // 2.
if (i >= raw.length()) {
throw new IllegalArgumentException("Code point indicated more characters than available")
}
}
sb.append(String.format("\\u%04x ", codePoint))
}
return sb.toString().trim()
}
@Test
void testSanitizeDnShouldHandleFuzzing() throws Exception {
// Arrange
final String DESIRED_NAME = SAFE_USER_NAME_JOHN
logger.info(" Desired name: ${DESIRED_NAME} | ${printUnicodeString(DESIRED_NAME)}")
// Contains various attempted >< escapes, trailing NULL, and BACKSPACE + 'n'
final List MALICIOUS_NAMES = [MALICIOUS_USER_NAME_JOHN,
SAFE_USER_NAME_JOHN + ">",
SAFE_USER_NAME_JOHN + "><>",
SAFE_USER_NAME_JOHN + "\\>",
SAFE_USER_NAME_JOHN + "\u003e",
SAFE_USER_NAME_JOHN + "\u005c\u005c\u003e",
SAFE_USER_NAME_JOHN + "\u0000",
SAFE_USER_NAME_JOHN + "\u0008n"]
// Act
MALICIOUS_NAMES.each { String name ->
logger.info(" Raw name: ${name} | ${printUnicodeString(name)}")
String sanitizedName = ProxiedEntitiesUtils.sanitizeDn(name)
logger.info("Sanitized name: ${sanitizedName} | ${printUnicodeString(sanitizedName)}")
// Assert
assertNotEquals(DESIRED_NAME, sanitizedName)
}
}
@Test
void testShouldFormatProxyDn() throws Exception {
// Arrange
final String DN = SAFE_USER_DN_JOHN
logger.info(" Provided proxy DN: ${DN}")
final String EXPECTED_PROXY_DN = "<${DN}>"
logger.info(" Expected proxy DN: ${EXPECTED_PROXY_DN}")
// Act
String forjohnedProxyDn = ProxiedEntitiesUtils.formatProxyDn(DN)
logger.info("Forjohned proxy DN: ${forjohnedProxyDn}")
// Assert
assertEquals(EXPECTED_PROXY_DN, forjohnedProxyDn)
}
@Test
void testFormatProxyDnShouldHandleMaliciousInput() throws Exception {
// Arrange
final String DN = MALICIOUS_USER_DN_JOHN
logger.info(" Provided proxy DN: ${DN}")
final String SANITIZED_DN = sanitizeDn(DN)
final String EXPECTED_PROXY_DN = "<${SANITIZED_DN}>"
logger.info(" Expected proxy DN: ${EXPECTED_PROXY_DN}")
// Act
String forjohnedProxyDn = ProxiedEntitiesUtils.formatProxyDn(DN)
logger.info("Forjohned proxy DN: ${forjohnedProxyDn}")
// Assert
assertEquals(EXPECTED_PROXY_DN, forjohnedProxyDn)
}
@Test
void testGetProxiedEntitiesChain() throws Exception {
// Arrange
String[] input = [SAFE_USER_NAME_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2]
final String expectedOutput = "<${SAFE_USER_NAME_JOHN}><${SAFE_USER_DN_PROXY_1}><${SAFE_USER_DN_PROXY_2}>"
// Act
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
// Assert
assertEquals(expectedOutput, output)
}
@Test
void testGetProxiedEntitiesChainShouldHandleMaliciousInput() throws Exception {
// Arrange
String[] input = [MALICIOUS_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2]
final String expectedOutput = "<${sanitizeDn(MALICIOUS_USER_DN_JOHN)}><${SAFE_USER_DN_PROXY_1}><${SAFE_USER_DN_PROXY_2}>"
// Act
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
// Assert
assertEquals(expectedOutput, output)
}
@Test
void testGetProxiedEntitiesChainShouldEncodeUnicode() throws Exception {
// Arrange
String[] input = [SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2]
final String expectedOutput = "<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}><${UNICODE_DN_2_ENCODED}>"
// Act
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
// Assert
assertEquals(expectedOutput, output)
}
@Test
void testFormatProxyDnShouldEncodeNonAsciiCharacters() throws Exception {
// Arrange
logger.info(" Provided DN: ${UNICODE_DN_1}")
final String expectedFormattedDn = "<${UNICODE_DN_1_ENCODED}>"
logger.info(" Expected DN: expected")
// Act
String formattedDn = ProxiedEntitiesUtils.formatProxyDn(UNICODE_DN_1)
logger.info("Formatted DN: ${formattedDn}")
// Assert
assertEquals(expectedFormattedDn, formattedDn)
}
@Test
void testShouldBuildProxyChain() throws Exception {
// Arrange
def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser
def mockJohn = [getIdentity: { -> SAFE_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser
// Act
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn)
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assertEquals("<${SAFE_USER_NAME_JOHN}><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
}
@Test
void testBuildProxyChainFromNullUserShouldBeAnonymous() throws Exception {
// Arrange
// Act
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(null)
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assertEquals("<>", proxiedEntitiesChain)
}
@Test
void testBuildProxyChainFromAnonymousUserShouldBeAnonymous() throws Exception {
// Arrange
def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser
def mockAnonymous = [getIdentity: { -> "anonymous" }, getChain: { -> mockProxy1 }, isAnonymous: { -> true}] as NiFiUser
// Act
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockAnonymous)
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assertEquals("<><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
}
@Test
void testBuildProxyChainShouldHandleUnicode() throws Exception {
// Arrange
def mockProxy1 = [getIdentity: { -> UNICODE_DN_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser
def mockJohn = [getIdentity: { -> SAFE_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser
// Act
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn)
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assertEquals("<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}>" as String, proxiedEntitiesChain)
}
@Test
void testBuildProxyChainShouldHandleMaliciousUser() throws Exception {
// Arrange
def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser
def mockJohn = [getIdentity: { -> MALICIOUS_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser
// Act
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn)
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
// Assert
assertEquals("<${MALICIOUS_USER_NAME_JOHN_ESCAPED}><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
}
@Test
void testShouldTokenizeProxiedEntitiesChainWithUserNames() throws Exception {
// Arrange
final List NAMES = [SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2]
final String RAW_PROXY_CHAIN = "<${NAMES.join("><")}>"
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
// Act
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assertEquals(NAMES, tokenizedNames)
}
@Test
void testShouldTokenizeAnonymous() throws Exception {
// Arrange
final List NAMES = [""]
final String RAW_PROXY_CHAIN = "<>"
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
// Act
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assertEquals(NAMES, tokenizedNames)
}
@Test
void testShouldTokenizeDoubleAnonymous() throws Exception {
// Arrange
final List NAMES = ["", ""]
final String RAW_PROXY_CHAIN = "<><>"
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
// Act
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assertEquals(NAMES, tokenizedNames)
}
@Test
void testShouldTokenizeNestedAnonymous() throws Exception {
// Arrange
final List NAMES = [SAFE_USER_DN_PROXY_1, "", SAFE_USER_DN_PROXY_2]
final String RAW_PROXY_CHAIN = "<${SAFE_USER_DN_PROXY_1}><><${SAFE_USER_DN_PROXY_2}>"
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
// Act
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assertEquals(NAMES, tokenizedNames)
}
@Test
void testShouldTokenizeProxiedEntitiesChainWithDNs() throws Exception {
// Arrange
final List DNS = [SAFE_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2]
final String RAW_PROXY_CHAIN = "<${DNS.join("><")}>"
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
// Act
def tokenizedDns = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
logger.info("Tokenized proxy chain: ${tokenizedDns.collect { "\"${it}\"" }}")
// Assert
assertEquals(DNS, tokenizedDns)
}
@Test
void testShouldTokenizeProxiedEntitiesChainWithAnonymousUser() throws Exception {
// Arrange
final List NAMES = ["", SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2]
final String RAW_PROXY_CHAIN = "<${NAMES.join("><")}>"
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
// Act
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
logger.info("Tokenized proxy chain: ${tokenizedNames}")
// Assert
assertEquals(NAMES, tokenizedNames)
}
@Test
void testTokenizeProxiedEntitiesChainShouldHandleMaliciousUser() throws Exception {
// Arrange
final List NAMES = [MALICIOUS_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2]
final String RAW_PROXY_CHAIN = "<${NAMES.collect { sanitizeDn(it) }.join("><")}>"
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
// Act
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
logger.info("Tokenized proxy chain: ${tokenizedNames.collect { "\"${it}\"" }}")
// Assert
assertEquals(NAMES, tokenizedNames)
assertEquals(NAMES.size(), tokenizedNames.size())
assertFalse(tokenizedNames.contains(SAFE_USER_NAME_JOHN))
}
@Test
void testTokenizeProxiedEntitiesChainShouldDecodeNonAsciiValues() throws Exception {
// Arrange
final String RAW_PROXY_CHAIN = "<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}><${UNICODE_DN_2_ENCODED}>"
final List TOKENIZED_NAMES = [SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2]
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
// Act
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
logger.info("Tokenized proxy chain: ${tokenizedNames.collect { "\"${it}\"" }}")
// Assert
assertEquals(TOKENIZED_NAMES, tokenizedNames)
assertEquals(TOKENIZED_NAMES.size(), tokenizedNames.size())
}
}

View File

@ -1,277 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security.requests
import org.apache.commons.lang3.StringUtils
import org.apache.nifi.stream.io.StreamUtils
import org.eclipse.jetty.server.LocalConnector
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.FilterHolder
import org.eclipse.jetty.servlet.ServletContextHandler
import org.eclipse.jetty.servlet.ServletHolder
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.servlet.DispatcherType
import javax.servlet.ServletException
import javax.servlet.ServletInputStream
import javax.servlet.http.HttpServlet
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
import java.util.concurrent.TimeUnit
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertTrue
class ContentLengthFilterTest {
private static final Logger logger = LoggerFactory.getLogger(ContentLengthFilterTest.class)
private static final int MAX_CONTENT_LENGTH = 1000
private static final int SERVER_IDLE_TIMEOUT = 2500 // only one request needed + value large enough for slow systems
private static final String POST_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s"
private static final String FORM_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\nContent-Type: application/x-www-form-urlencoded\r\nAccept-Charset: UTF-8\r\n\r\n%s"
public static final int FORM_CONTENT_SIZE = 128
// These variables hold data for content small enough to be allowed
private static final int SMALL_CLAIM_SIZE_BYTES = 150
private static final String SMALL_PAYLOAD = "1" * SMALL_CLAIM_SIZE_BYTES
// These variables hold data for content too large to be allowed
private static final int LARGE_CLAIM_SIZE_BYTES = 2000
private static final String LARGE_PAYLOAD = "1" * LARGE_CLAIM_SIZE_BYTES
private Server serverUnderTest
private LocalConnector localConnector
private ServletContextHandler contextUnderTest
@BeforeEach
void setUp() {
createSimpleReadServer()
}
@AfterEach
void tearDown() {
stopServer()
}
void stopServer() throws Exception {
if (serverUnderTest && serverUnderTest.isRunning()) {
serverUnderTest.stop()
}
}
private void configureAndStartServer(HttpServlet servlet, int maxFormContentSize) throws Exception {
serverUnderTest = new Server()
localConnector = new LocalConnector(serverUnderTest)
localConnector.setIdleTimeout(SERVER_IDLE_TIMEOUT)
serverUnderTest.addConnector(localConnector)
contextUnderTest = new ServletContextHandler(serverUnderTest, "/")
if (maxFormContentSize > 0) {
contextUnderTest.setMaxFormContentSize(maxFormContentSize)
}
contextUnderTest.addServlet(new ServletHolder(servlet), "/*")
// This only adds the ContentLengthFilter if a valid maxFormContentSize is not provided
if (maxFormContentSize < 0) {
FilterHolder holder = contextUnderTest.addFilter(ContentLengthFilter.class, "/*", EnumSet.of(DispatcherType.REQUEST) as EnumSet<DispatcherType>)
holder.setInitParameter(ContentLengthFilter.MAX_LENGTH_INIT_PARAM, String.valueOf(MAX_CONTENT_LENGTH))
}
serverUnderTest.start()
}
/**
* Initializes a server which consumes any provided request input stream and returns HTTP 200. It has no
* {@code maxFormContentSize}, so the {@link ContentLengthFilter} is applied. The response contains a header and the
* response body indicating the total number of request content bytes read.
*
* @throws Exception if there is a problem setting up the server
*/
private void createSimpleReadServer() throws Exception {
HttpServlet mockServlet = [
doPost: { HttpServletRequest req, HttpServletResponse resp ->
byte[] byteBuffer = new byte[2048]
int bytesRead = StreamUtils.fillBuffer(req.getInputStream(), byteBuffer, false)
resp.setHeader("Bytes-Read", bytesRead as String)
resp.setStatus(HttpServletResponse.SC_OK)
resp.getWriter().write("Read ${bytesRead} bytes of request input")
}
] as HttpServlet
configureAndStartServer(mockServlet, -1)
}
private static void logResponse(String response, String s = "Response: ") {
String responseId = String.valueOf(System.currentTimeMillis() % 100)
final String delimiterLine = "\n-----" + responseId + "-----\n"
String formattedResponse = s + delimiterLine + response + delimiterLine
logger.info(formattedResponse)
}
@Test
void testRequestsWithMissingContentLengthHeader() throws Exception {
// This shows that the ContentLengthFilter allows a request that does not have a content-length header.
String response = localConnector.getResponse("POST / HTTP/1.0\r\n\r\n")
assertFalse(StringUtils.containsIgnoreCase(response, "411 Length Required"))
}
/**
* This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends more than
* the max.
*/
@Test
void testShouldRejectRequestWithLongContentLengthHeader() throws Exception {
// Arrange
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
logger.info("Making request with CL: ${LARGE_CLAIM_SIZE_BYTES} and actual length: ${LARGE_PAYLOAD.length()}")
// Act
String response = localConnector.getResponse(requestBody)
logResponse(response)
// Assert
assertTrue(response.contains("413 Payload Too Large"))
}
/**
* This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends less than
* the claim.
*/
@Test
void testShouldRejectRequestWithLongContentLengthHeaderAndSmallPayload() throws Exception {
// Arrange
String incompletePayload = "1" * (SMALL_CLAIM_SIZE_BYTES / 2)
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, incompletePayload)
logger.info("Making request with CL: ${LARGE_CLAIM_SIZE_BYTES} and actual length: ${incompletePayload.length()}")
// Act
String response = localConnector.getResponse(requestBody)
logResponse(response)
// Assert
assertTrue(response.contains("413 Payload Too Large"))
}
/**
* This shows that the ContentLengthFilter <em>allows</em> a request when the client claims less
* than the max + sends more than the max, but restricts the request body to the stated content
* length size.
*/
@Test
void testShouldRejectRequestWithSmallContentLengthHeaderAndLargePayload() throws Exception {
// Arrange
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
logger.info("Making request with CL: ${SMALL_CLAIM_SIZE_BYTES} and actual length: ${LARGE_PAYLOAD.length()}")
// Act
String response = localConnector.getResponse(requestBody)
logResponse(response)
// Assert
assertTrue(response.contains("200"))
assertTrue(response.contains("Bytes-Read: " + SMALL_CLAIM_SIZE_BYTES))
assertTrue(response.contains("Read " + SMALL_CLAIM_SIZE_BYTES + " bytes"))
}
/**
* This shows that the server times out when the client claims less than the max + sends less than the max + sends
* less than it claims to send.
*/
@Test
void testShouldTimeoutRequestWithSmallContentLengthHeaderAndSmallerPayload() throws Exception {
// Arrange
String smallerPayload = SMALL_PAYLOAD[0..(SMALL_PAYLOAD.length() / 2)]
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, smallerPayload)
logger.info("Making request with CL: ${SMALL_CLAIM_SIZE_BYTES} and actual length: ${smallerPayload.length()}")
// Act
String response = localConnector.getResponse(requestBody, 500, TimeUnit.MILLISECONDS)
logResponse(response)
// Assert
assertTrue(response.contains("500 Server Error"))
assertTrue(response.contains("Timeout"))
}
@Test
void testFilterShouldAllowSiteToSiteTransfer() throws Exception {
// Arrange
final String SITE_TO_SITE_POST_REQUEST = "POST /nifi-api/data-transfer/input-ports HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s"
final String siteToSiteRequest = String.format(SITE_TO_SITE_POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
logResponse(siteToSiteRequest, "Request: ")
// Act
String response = localConnector.getResponse(siteToSiteRequest)
logResponse(response)
// Assert
assertTrue(response.contains("200 OK"))
}
@Test
void testJettyMaxFormSize() throws Exception {
// This shows that the jetty server option for 'maxFormContentSize' is insufficient for our needs because it
// catches requests like this:
// Configure the server but do not apply the CLF because the FORM_CONTENT_SIZE > 0
configureAndStartServer(new HttpServlet() {
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
try {
req.getParameterMap()
ServletInputStream input = req.getInputStream()
int count = 0
while (!input.isFinished()) {
input.read()
count += 1
}
final int FORM_LIMIT_BYTES = FORM_CONTENT_SIZE + "a=\n".length()
if (count > FORM_LIMIT_BYTES) {
logger.warn("Bytes read ({}) is larger than the limit ({})", count, FORM_LIMIT_BYTES)
resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Should not reach this code.")
} else {
logger.warn("Bytes read ({}) is less than or equal to the limit ({})", count, FORM_LIMIT_BYTES)
resp.sendError(HttpServletResponse.SC_EXPECTATION_FAILED, "Read Too Many Bytes")
}
} catch (final Exception e) {
// This is the jetty context returning a 400 from the maxFormContentSize setting:
if (StringUtils.containsIgnoreCase(e.getCause().toString(), "Form is larger than max length " + FORM_CONTENT_SIZE)) {
logger.warn("Exception thrown by input stream: ", e)
resp.sendError(HttpServletResponse.SC_REQUEST_ENTITY_TOO_LARGE, "Payload Too Large")
} else {
logger.warn("Exception thrown by input stream: ", e)
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Should not reach this code, either.")
}
}
}
}, FORM_CONTENT_SIZE)
// Test to catch a form submission that exceeds the FORM_CONTENT_SIZE limit
String form = "a=" + "1" * FORM_CONTENT_SIZE
String response = localConnector.getResponse(String.format(FORM_REQUEST, form.length(), form))
logResponse(response)
assertTrue(response.contains("413 Payload Too Large"))
// But it does not catch requests like this:
response = localConnector.getResponse(String.format(POST_REQUEST, form.length(), form + form))
assertTrue(response.contains("417 Read Too Many Bytes"))
}
}

View File

@ -0,0 +1,242 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security;
import org.apache.nifi.authorization.user.NiFiUser;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class ProxiedEntitiesUtilsTest {
private static final String SAFE_USER_NAME_JOHN = "jdoe";
private static final String SAFE_USER_DN_JOHN = "CN=" + SAFE_USER_NAME_JOHN + ", OU=Apache NiFi";
private static final String SAFE_USER_NAME_PROXY_1 = "proxy1.nifi.apache.org";
private static final String SAFE_USER_DN_PROXY_1 = "CN=" + SAFE_USER_NAME_PROXY_1 + ", OU=Apache NiFi";
private static final String SAFE_USER_NAME_PROXY_2 = "proxy2.nifi.apache.org";
private static final String SAFE_USER_DN_PROXY_2 = "CN=" + SAFE_USER_NAME_PROXY_2 + ", OU=Apache NiFi";
private static final String MALICIOUS_USER_NAME_JOHN = SAFE_USER_NAME_JOHN + ", OU=Apache NiFi><CN=" + SAFE_USER_NAME_PROXY_1;
private static final String MALICIOUS_USER_DN_JOHN = "CN=" + MALICIOUS_USER_NAME_JOHN + ", OU=Apache NiFi";
private static final String MALICIOUS_USER_NAME_JOHN_ESCAPED = sanitizeDn(MALICIOUS_USER_NAME_JOHN);
private static final String UNICODE_DN_1 = "CN=Алйс, OU=Apache NiFi";
private static final String UNICODE_DN_1_ENCODED = "<" + base64Encode(UNICODE_DN_1) + ">";
private static final String UNICODE_DN_2 = "CN=Боб, OU=Apache NiFi";
private static final String UNICODE_DN_2_ENCODED = "<" + base64Encode(UNICODE_DN_2) + ">";
private static final String ANONYMOUS_USER = "";
private static final String ANONYMOUS_PROXIED_ENTITY_CHAIN = "<>";
private static String sanitizeDn(String dn) {
return dn.replaceAll(">", "\\\\>").replaceAll("<", "\\\\<");
}
private static String base64Encode(String dn) {
return Base64.getEncoder().encodeToString(dn.getBytes(StandardCharsets.UTF_8));
}
@ParameterizedTest
@MethodSource("getMaliciousNames" )
public void testSanitizeDnShouldHandleFuzzing(String maliciousName) {
assertNotEquals(formatDn(SAFE_USER_NAME_JOHN), ProxiedEntitiesUtils.formatProxyDn(maliciousName));
}
// Contains various attempted >< escapes, trailing NULL, and BACKSPACE + 'n'
private static List<String> getMaliciousNames() {
return Arrays.asList(MALICIOUS_USER_NAME_JOHN,
SAFE_USER_NAME_JOHN + ">",
SAFE_USER_NAME_JOHN + "><>",
SAFE_USER_NAME_JOHN + "\\>",
SAFE_USER_NAME_JOHN + "\u003e",
SAFE_USER_NAME_JOHN + "\u005c\u005c\u003e",
SAFE_USER_NAME_JOHN + "\u0000",
SAFE_USER_NAME_JOHN + "\u0008n");
}
@Test
public void testShouldFormatProxyDn() {
assertEquals(formatDn(SAFE_USER_DN_JOHN), ProxiedEntitiesUtils.formatProxyDn(SAFE_USER_DN_JOHN));
}
@Test
public void testFormatProxyDnShouldHandleMaliciousInput() {
assertEquals(formatSanitizedDn(MALICIOUS_USER_DN_JOHN), ProxiedEntitiesUtils.formatProxyDn(MALICIOUS_USER_DN_JOHN));
}
@Test
public void testGetProxiedEntitiesChain() {
String[] input = new String [] {SAFE_USER_NAME_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2};
assertEquals(formatDns(input), ProxiedEntitiesUtils.getProxiedEntitiesChain(input));
}
@Test
public void testGetProxiedEntitiesChainShouldHandleMaliciousInput() {
final String expectedOutput = formatSanitizedDn(MALICIOUS_USER_DN_JOHN) + formatDns(SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2);
assertEquals(expectedOutput, ProxiedEntitiesUtils.getProxiedEntitiesChain(MALICIOUS_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2));
}
@Test
public void testGetProxiedEntitiesChainShouldEncodeUnicode() {
assertEquals(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED, UNICODE_DN_2_ENCODED),
ProxiedEntitiesUtils.getProxiedEntitiesChain(SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2));
}
@Test
public void testFormatProxyDnShouldEncodeNonAsciiCharacters() {
assertEquals(formatDn(UNICODE_DN_1_ENCODED), ProxiedEntitiesUtils.formatProxyDn(UNICODE_DN_1));
}
@Test
public void testShouldBuildProxyChain(@Mock NiFiUser proxy1, @Mock NiFiUser john) {
when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1);
when(proxy1.getChain()).thenReturn(null);
when(proxy1.isAnonymous()).thenReturn(false);
when(john.getIdentity()).thenReturn(SAFE_USER_NAME_JOHN);
when(john.getChain()).thenReturn(proxy1);
when(john.isAnonymous()).thenReturn(false);
assertEquals(formatDns(SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john));
}
@Test
public void testBuildProxyChainFromNullUserShouldBeAnonymous() {
assertEquals(ANONYMOUS_PROXIED_ENTITY_CHAIN, ProxiedEntitiesUtils.buildProxiedEntitiesChainString(null));
}
@Test
public void testBuildProxyChainFromAnonymousUserShouldBeAnonymous(@Mock NiFiUser proxy1, @Mock NiFiUser anonymous) {
when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1);
when(proxy1.getChain()).thenReturn(null);
when(proxy1.isAnonymous()).thenReturn(false);
when(anonymous.getChain()).thenReturn(proxy1);
when(anonymous.isAnonymous()).thenReturn(true);
assertEquals(formatDns(ANONYMOUS_USER, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(anonymous));
}
@Test
public void testBuildProxyChainShouldHandleUnicode(@Mock NiFiUser proxy1, @Mock NiFiUser john) {
when(proxy1.getIdentity()).thenReturn(UNICODE_DN_1);
when(proxy1.getChain()).thenReturn(null);
when(proxy1.isAnonymous()).thenReturn(false);
when(john.getIdentity()).thenReturn(SAFE_USER_NAME_JOHN);
when(john.getChain()).thenReturn(proxy1);
when(john.isAnonymous()).thenReturn(false);
assertEquals(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john));
}
@Test
public void testBuildProxyChainShouldHandleMaliciousUser(@Mock NiFiUser proxy1, @Mock NiFiUser john) {
when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1);
when(proxy1.getChain()).thenReturn(null);
when(proxy1.isAnonymous()).thenReturn(false);
when(john.getIdentity()).thenReturn(MALICIOUS_USER_NAME_JOHN);
when(john.getChain()).thenReturn(proxy1);
when(john.isAnonymous()).thenReturn(false);
assertEquals(formatDns(MALICIOUS_USER_NAME_JOHN_ESCAPED, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john));
}
@Test
public void testShouldTokenizeProxiedEntitiesChainWithUserNames() {
final List<String> names = Arrays.asList(SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2);
final String rawProxyChain = formatDns(names.toArray(new String[0]));
assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain));
}
@Test
public void testShouldTokenizeAnonymous() {
assertEquals(Collections.singletonList(ANONYMOUS_USER), ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(ANONYMOUS_PROXIED_ENTITY_CHAIN));
}
@Test
public void testShouldTokenizeDoubleAnonymous() {
assertEquals(Arrays.asList(ANONYMOUS_USER, ANONYMOUS_USER), ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(ANONYMOUS_PROXIED_ENTITY_CHAIN.repeat(2)));
}
@Test
public void testShouldTokenizeNestedAnonymous() {
final List<String> names = Arrays.asList(SAFE_USER_DN_PROXY_1, ANONYMOUS_USER, SAFE_USER_DN_PROXY_2);
final String rawProxyChain = formatDns(names.toArray(new String [0]));
assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain));
}
@Test
public void testShouldTokenizeProxiedEntitiesChainWithDNs() {
final List<String> dns = Arrays.asList(SAFE_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2);
final String rawProxyChain = formatDns(dns.toArray(new String[0]));
assertEquals(dns, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain));
}
@Test
public void testShouldTokenizeProxiedEntitiesChainWithAnonymousUser() {
final List<String> names = Arrays.asList(ANONYMOUS_USER, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2);
final String rawProxyChain = formatDns(names.toArray(new String[0]));
assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain));
}
@Test
public void testTokenizeProxiedEntitiesChainShouldHandleMaliciousUser() {
final List<String> names = Arrays.asList(MALICIOUS_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2);
final String rawProxyChain = names.stream()
.map(this::formatSanitizedDn)
.collect(Collectors.joining());
List<String> tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain);
assertEquals(names, tokenizedNames);
assertFalse(tokenizedNames.contains(SAFE_USER_NAME_JOHN));
}
@Test
public void testTokenizeProxiedEntitiesChainShouldDecodeNonAsciiValues() {
List<String> tokenizedNames =
ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED, UNICODE_DN_2_ENCODED));
assertEquals(Arrays.asList(SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2), tokenizedNames);
}
private String formatSanitizedDn(String dn) {
return formatDn((sanitizeDn(dn)));
}
private String formatDn(String dn) {
return formatDns(dn);
}
private String formatDns(String...dns) {
return Arrays.stream(dns)
.collect(Collectors.joining("><", "<", ">"));
}
}

View File

@ -0,0 +1,232 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security.requests;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.stream.io.StreamUtils;
import org.eclipse.jetty.server.LocalConnector;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import javax.servlet.DispatcherType;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.EnumSet;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@ExtendWith(MockitoExtension.class)
class ContentLengthFilterTest {
private static final String POST_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s";
public static final int FORM_CONTENT_SIZE = 128;
// These variables hold data for content small enough to be allowed
private static final int SMALL_CLAIM_SIZE_BYTES = 150;
private static final String SMALL_PAYLOAD = "1".repeat(SMALL_CLAIM_SIZE_BYTES);
// These variables hold data for content too large to be allowed
private static final int LARGE_CLAIM_SIZE_BYTES = 2000;
private static final String LARGE_PAYLOAD = "1".repeat(LARGE_CLAIM_SIZE_BYTES);
private Server serverUnderTest;
private LocalConnector localConnector;
@BeforeEach
public void setUp() throws Exception {
createSimpleReadServer();
}
@AfterEach
public void tearDown() throws Exception {
stopServer();
}
@Test
public void testRequestsWithMissingContentLengthHeader() throws Exception {
// This shows that the ContentLengthFilter allows a request that does not have a content-length header.
String response = localConnector.getResponse("POST / HTTP/1.0\r\n\r\n");
assertFalse(StringUtils.containsIgnoreCase(response, "411 Length Required"));
}
/**
* This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends more than
* the max.
*/
@Test
public void testShouldRejectRequestWithLongContentLengthHeader() throws Exception {
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD);
String response = localConnector.getResponse(requestBody);
assertTrue(response.contains("413 Payload Too Large"));
}
/**
* This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends less than
* the claim.
*/
@Test
public void testShouldRejectRequestWithLongContentLengthHeaderAndSmallPayload() throws Exception {
String incompletePayload = "1".repeat(SMALL_CLAIM_SIZE_BYTES / 2);
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, incompletePayload);
String response = localConnector.getResponse(requestBody);
assertTrue(response.contains("413 Payload Too Large"));
}
/**
* This shows that the ContentLengthFilter <em>allows</em> a request when the client claims less
* than the max + sends more than the max, but restricts the request body to the stated content
* length size.
*/
@Test
public void testShouldRejectRequestWithSmallContentLengthHeaderAndLargePayload() throws Exception {
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, LARGE_PAYLOAD);
String response = localConnector.getResponse(requestBody);
assertTrue(response.contains("200"));
assertTrue(response.contains("Bytes-Read: " + SMALL_CLAIM_SIZE_BYTES));
assertTrue(response.contains("Read " + SMALL_CLAIM_SIZE_BYTES + " bytes"));
}
/**
* This shows that the server times out when the client claims less than the max + sends less than the max + sends
* less than it claims to send.
*/
@Test
public void testShouldTimeoutRequestWithSmallContentLengthHeaderAndSmallerPayload() throws Exception {
String smallerPayload = SMALL_PAYLOAD.substring(0, SMALL_PAYLOAD.length() / 2);
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, smallerPayload);
String response = localConnector.getResponse(requestBody, 500, TimeUnit.MILLISECONDS);
assertTrue(response.contains("500 Server Error"));
assertTrue(response.contains("Timeout"));
}
@Test
public void testFilterShouldAllowSiteToSiteTransfer() throws Exception {
final String siteToSitePostRequest = "POST /nifi-api/data-transfer/input-ports HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s";
final String siteToSiteRequest = String.format(siteToSitePostRequest, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD);
String response = localConnector.getResponse(siteToSiteRequest);
assertTrue(response.contains("200 OK"));
}
@Test
void testJettyMaxFormSize() throws Exception {
// This shows that the jetty server option for 'maxFormContentSize' is insufficient for our needs because it
// catches requests like this:
// Configure the server but do not apply the CLF because the FORM_CONTENT_SIZE > 0
configureAndStartServer(new HttpServlet() {
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
try {
req.getParameterMap();
ServletInputStream input = req.getInputStream();
int count = 0;
while (!input.isFinished()) {
input.read();
count += 1;
}
final int formLimitBytes = FORM_CONTENT_SIZE + "a=\n".length();
if (count > formLimitBytes) {
resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Should not reach this code.");
} else {
resp.sendError(HttpServletResponse.SC_EXPECTATION_FAILED, "Read Too Many Bytes");
}
} catch (final Exception e) {
// This is the jetty context returning a 400 from the maxFormContentSize setting:
if (StringUtils.containsIgnoreCase(e.getCause().toString(), "Form is larger than max length " + FORM_CONTENT_SIZE)) {
resp.sendError(HttpServletResponse.SC_REQUEST_ENTITY_TOO_LARGE, "Payload Too Large");
} else {
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Should not reach this code, either.");
}
}
}
}, FORM_CONTENT_SIZE);
// Test to catch a form submission that exceeds the FORM_CONTENT_SIZE limit
String form = "a=" + "1".repeat(FORM_CONTENT_SIZE);
final String formRequest = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\nContent-Type: application/x-www-form-urlencoded\r\nAccept-Charset: UTF-8\r\n\r\n%s";
String response = localConnector.getResponse(String.format(formRequest, form.length(), form));
assertTrue(response.contains("413 Payload Too Large"));
// But it does not catch requests like this:
response = localConnector.getResponse(String.format(POST_REQUEST, form.length(), form + form));
assertTrue(response.contains("417 Read Too Many Bytes"));
}
/**
* Initializes a server which consumes any provided request input stream and returns HTTP 200. It has no
* {@code maxFormContentSize}, so the {@link ContentLengthFilter} is applied. The response contains a header and the
* response body indicating the total number of request content bytes read.
*
* @throws Exception if there is a problem setting up the server
*/
private void createSimpleReadServer() throws Exception {
HttpServlet mockServlet = new HttpServlet() {
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
byte[] byteBuffer = new byte[2048];
int bytesRead = StreamUtils.fillBuffer(req.getInputStream(), byteBuffer, false);
resp.setHeader("Bytes-Read", Integer.toString(bytesRead));
resp.setStatus(HttpServletResponse.SC_OK);
resp.getWriter().write("Read " + bytesRead + " bytes of request input");
}
};
configureAndStartServer(mockServlet, -1);
}
private void configureAndStartServer(HttpServlet servlet, int maxFormContentSize) throws Exception {
serverUnderTest = new Server();
localConnector = new LocalConnector(serverUnderTest);
localConnector.setIdleTimeout(2500); // only one request needed + value large enough for slow systems
serverUnderTest.addConnector(localConnector);
ServletContextHandler contextUnderTest = new ServletContextHandler(serverUnderTest, "/");
if (maxFormContentSize > 0) {
contextUnderTest.setMaxFormContentSize(maxFormContentSize);
}
contextUnderTest.addServlet(new ServletHolder(servlet), "/*");
// This only adds the ContentLengthFilter if a valid maxFormContentSize is not provided
if (maxFormContentSize < 0) {
FilterHolder holder = contextUnderTest.addFilter(ContentLengthFilter.class, "/*", EnumSet.of(DispatcherType.REQUEST));
holder.setInitParameter(ContentLengthFilter.MAX_LENGTH_INIT_PARAM, String.valueOf(1000));
}
serverUnderTest.start();
}
void stopServer() throws Exception {
if (serverUnderTest != null && serverUnderTest.isRunning()) {
serverUnderTest.stop();
}
}
}