mirror of https://github.com/apache/nifi.git
NIFI-11767 Refactored Groovy tests in nifi-web-error and nifi-web-security to Java
This closes #7457 Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
parent
b3372900b3
commit
d24318cdb8
|
@ -1,135 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.filter
|
||||
|
||||
|
||||
import org.junit.jupiter.api.BeforeAll
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
|
||||
import javax.servlet.FilterChain
|
||||
import javax.servlet.FilterConfig
|
||||
import javax.servlet.RequestDispatcher
|
||||
import javax.servlet.ServletContext
|
||||
import javax.servlet.ServletRequest
|
||||
import javax.servlet.ServletResponse
|
||||
import javax.servlet.http.HttpServletRequest
|
||||
import javax.servlet.http.HttpServletResponse
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals
|
||||
|
||||
class CatchAllFilterTest {
|
||||
private static final Logger logger = LoggerFactory.getLogger(CatchAllFilterTest.class)
|
||||
|
||||
@BeforeAll
|
||||
static void setUpOnce() throws Exception {
|
||||
logger.metaClass.methodMissing = { String name, args ->
|
||||
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
|
||||
}
|
||||
}
|
||||
|
||||
private static String getValue(String parameterName, Map<String, String> params = [:]) {
|
||||
params.containsKey(parameterName) ? params[parameterName] : ""
|
||||
}
|
||||
|
||||
@Test
|
||||
void testInitShouldCallSuper() {
|
||||
// Arrange
|
||||
def EXPECTED_ALLOWED_CONTEXT_PATHS = ["/path1", "/path2"].join(", ")
|
||||
def parameters = [allowedContextPaths: EXPECTED_ALLOWED_CONTEXT_PATHS]
|
||||
FilterConfig mockFilterConfig = [
|
||||
getInitParameter : { String parameterName ->
|
||||
return getValue(parameterName, parameters)
|
||||
},
|
||||
getServletContext: { ->
|
||||
[getInitParameter: { String parameterName ->
|
||||
return getValue(parameterName, parameters)
|
||||
}] as ServletContext
|
||||
}] as FilterConfig
|
||||
|
||||
CatchAllFilter caf = new CatchAllFilter()
|
||||
|
||||
// Act
|
||||
caf.init(mockFilterConfig)
|
||||
logger.info("Allowed context paths: ${caf.getAllowedContextPaths()}")
|
||||
|
||||
// Assert
|
||||
assertEquals(EXPECTED_ALLOWED_CONTEXT_PATHS, caf.getAllowedContextPaths())
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldDoFilter() {
|
||||
// Arrange
|
||||
final String EXPECTED_ALLOWED_CONTEXT_PATHS = ["/path1", "/path2"].join(", ")
|
||||
final String EXPECTED_FORWARD_PATH = "index.jsp"
|
||||
final Map PARAMETERS = [
|
||||
allowedContextPaths: EXPECTED_ALLOWED_CONTEXT_PATHS,
|
||||
forwardPath : EXPECTED_FORWARD_PATH
|
||||
]
|
||||
|
||||
final String EXPECTED_CONTEXT_PATH = ""
|
||||
|
||||
// Mock collaborators
|
||||
FilterConfig mockFilterConfig = [
|
||||
getInitParameter : { String parameterName ->
|
||||
return getValue(parameterName, PARAMETERS)
|
||||
},
|
||||
getServletContext: { ->
|
||||
[getInitParameter: { String parameterName ->
|
||||
return getValue(parameterName, PARAMETERS)
|
||||
}] as ServletContext
|
||||
}] as FilterConfig
|
||||
|
||||
// Local map to store request attributes
|
||||
def requestAttributes = [:]
|
||||
|
||||
// Local string to store resulting path
|
||||
String forwardedRequestTo = ""
|
||||
|
||||
final Map HEADERS = [
|
||||
"X-ProxyContextPath" : "",
|
||||
"X-Forwarded-Context": "",
|
||||
"X-Forwarded-Prefix" : ""]
|
||||
|
||||
HttpServletRequest mockRequest = [
|
||||
getContextPath : { -> EXPECTED_CONTEXT_PATH },
|
||||
getHeader : { String headerName -> getValue(headerName, HEADERS) },
|
||||
setAttribute : { String attr, String value ->
|
||||
requestAttributes[attr] = value
|
||||
logger.mock("Set request attribute ${attr} to ${value}")
|
||||
},
|
||||
getRequestDispatcher: { String path ->
|
||||
[forward: { ServletRequest request, ServletResponse response ->
|
||||
forwardedRequestTo = path
|
||||
logger.mock("Forwarded request to ${path}")
|
||||
}] as RequestDispatcher
|
||||
}] as HttpServletRequest
|
||||
HttpServletResponse mockResponse = [:] as HttpServletResponse
|
||||
FilterChain mockFilterChain = [:] as FilterChain
|
||||
|
||||
CatchAllFilter caf = new CatchAllFilter()
|
||||
caf.init(mockFilterConfig)
|
||||
logger.info("Allowed context paths: ${caf.getAllowedContextPaths()}")
|
||||
|
||||
// Act
|
||||
caf.doFilter(mockRequest, mockResponse, mockFilterChain)
|
||||
|
||||
// Assert
|
||||
assertEquals(EXPECTED_FORWARD_PATH, forwardedRequestTo)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.filter;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.FilterConfig;
|
||||
import javax.servlet.RequestDispatcher;
|
||||
import javax.servlet.ServletContext;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class CatchAllFilterTest {
|
||||
@Mock
|
||||
private ServletContext servletContext;
|
||||
|
||||
@Mock
|
||||
private FilterConfig filterConfig;
|
||||
|
||||
@BeforeEach
|
||||
public void setUp() {
|
||||
when(filterConfig.getServletContext()).thenReturn(servletContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInitShouldCallSuper() throws ServletException {
|
||||
String expectedAllowedContextPaths = getExpectedAllowedContextPaths();
|
||||
final Map<String, String> parameters = Collections.singletonMap("allowedContextPaths", getExpectedAllowedContextPaths());
|
||||
when(servletContext.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters));
|
||||
when(filterConfig.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters));
|
||||
|
||||
CatchAllFilter catchAllFilter = new CatchAllFilter();
|
||||
catchAllFilter.init(filterConfig);
|
||||
|
||||
assertEquals(expectedAllowedContextPaths, catchAllFilter.getAllowedContextPaths());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldDoFilter(@Mock HttpServletRequest request, @Mock RequestDispatcher requestDispatcher,
|
||||
@Mock HttpServletResponse response, @Mock FilterChain filterChain ) throws ServletException, IOException {
|
||||
final String expectedAllowedContextPaths = getExpectedAllowedContextPaths();
|
||||
final String expectedForwardPath = "index.jsp";
|
||||
final Map<String, String> parameters = new HashMap<>();
|
||||
parameters.put("allowedContextPaths", expectedAllowedContextPaths);
|
||||
parameters.put("forwardPath", expectedForwardPath);
|
||||
final Map<String, Object> requestAttributes = new HashMap<>();
|
||||
final String[] forwardedRequestTo = new String[1];
|
||||
final Map<String, String> headers = new HashMap<>();
|
||||
headers.put("X-ProxyContextPath", "");
|
||||
headers.put("X-Forwarded-Context", "");
|
||||
headers.put("X-Forwarded-Prefix", "");
|
||||
|
||||
when(servletContext.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters));
|
||||
when(filterConfig.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters));
|
||||
when(request.getHeader(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), headers));
|
||||
doAnswer(invocation -> requestAttributes.put(invocation.getArgument(0), invocation.getArgument(1))).when(request).setAttribute(anyString(), any());
|
||||
when(request.getRequestDispatcher(anyString())).thenAnswer(outerInvocation -> {
|
||||
doAnswer(innerInvocation -> forwardedRequestTo[0] = outerInvocation.getArgument(0)).when(requestDispatcher).forward(any(), any());
|
||||
return requestDispatcher;});
|
||||
|
||||
CatchAllFilter catchAllFilter = new CatchAllFilter();
|
||||
catchAllFilter.init(filterConfig);
|
||||
catchAllFilter.doFilter(request, response, filterChain);
|
||||
|
||||
assertEquals(expectedForwardPath, forwardedRequestTo[0]);
|
||||
}
|
||||
|
||||
private String getExpectedAllowedContextPaths() {
|
||||
return String.join(",", "/path1", "/path2");
|
||||
}
|
||||
|
||||
private static String getValue(String parameterName, Map<String, String> params) {
|
||||
return params.getOrDefault(parameterName, "");
|
||||
}
|
||||
}
|
|
@ -251,11 +251,5 @@
|
|||
<artifactId>jetty-servlet</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.codehaus.groovy</groupId>
|
||||
<artifactId>groovy-json</artifactId>
|
||||
<version>${nifi.groovy.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
|
|
@ -1,393 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.security
|
||||
|
||||
import org.apache.nifi.authorization.user.NiFiUser
|
||||
import org.junit.jupiter.api.BeforeAll
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals
|
||||
|
||||
class ProxiedEntitiesUtilsTest {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ProxiedEntitiesUtils.class)
|
||||
|
||||
private static final String SAFE_USER_NAME_JOHN = "jdoe"
|
||||
private static final String SAFE_USER_DN_JOHN = "CN=${SAFE_USER_NAME_JOHN}, OU=Apache NiFi"
|
||||
|
||||
private static final String SAFE_USER_NAME_PROXY_1 = "proxy1.nifi.apache.org"
|
||||
private static final String SAFE_USER_DN_PROXY_1 = "CN=${SAFE_USER_NAME_PROXY_1}, OU=Apache NiFi"
|
||||
|
||||
private static final String SAFE_USER_NAME_PROXY_2 = "proxy2.nifi.apache.org"
|
||||
private static final String SAFE_USER_DN_PROXY_2 = "CN=${SAFE_USER_NAME_PROXY_2}, OU=Apache NiFi"
|
||||
|
||||
private static
|
||||
final String MALICIOUS_USER_NAME_JOHN = "${SAFE_USER_NAME_JOHN}, OU=Apache NiFi><CN=${SAFE_USER_NAME_PROXY_1}"
|
||||
private static final String MALICIOUS_USER_DN_JOHN = "CN=${MALICIOUS_USER_NAME_JOHN}, OU=Apache NiFi"
|
||||
|
||||
private static
|
||||
final String MALICIOUS_USER_NAME_JOHN_ESCAPED = sanitizeDn(MALICIOUS_USER_NAME_JOHN)
|
||||
|
||||
private static final String UNICODE_DN_1 = "CN=Алйс, OU=Apache NiFi"
|
||||
private static final String UNICODE_DN_1_ENCODED = "<" + base64Encode(UNICODE_DN_1) + ">"
|
||||
|
||||
private static final String UNICODE_DN_2 = "CN=Боб, OU=Apache NiFi"
|
||||
private static final String UNICODE_DN_2_ENCODED = "<" + base64Encode(UNICODE_DN_2) + ">"
|
||||
|
||||
@BeforeAll
|
||||
static void setUpOnce() throws Exception {
|
||||
logger.metaClass.methodMissing = { String name, args ->
|
||||
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
|
||||
}
|
||||
}
|
||||
|
||||
private static String sanitizeDn(String dn = "") {
|
||||
dn.replaceAll(/>/, '\\\\>').replaceAll('<', '\\\\<')
|
||||
}
|
||||
|
||||
private static String base64Encode(String dn = "") {
|
||||
return Base64.getEncoder().encodeToString(dn.getBytes(StandardCharsets.UTF_8))
|
||||
}
|
||||
|
||||
private static String printUnicodeString(final String raw) {
|
||||
StringBuilder sb = new StringBuilder()
|
||||
for (int i = 0; i < raw.size(); i++) {
|
||||
int codePoint = Character.codePointAt(raw, i)
|
||||
int charCount = Character.charCount(codePoint)
|
||||
if (charCount > 1) {
|
||||
i += charCount - 1 // 2.
|
||||
if (i >= raw.length()) {
|
||||
throw new IllegalArgumentException("Code point indicated more characters than available")
|
||||
}
|
||||
}
|
||||
sb.append(String.format("\\u%04x ", codePoint))
|
||||
}
|
||||
return sb.toString().trim()
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSanitizeDnShouldHandleFuzzing() throws Exception {
|
||||
// Arrange
|
||||
final String DESIRED_NAME = SAFE_USER_NAME_JOHN
|
||||
logger.info(" Desired name: ${DESIRED_NAME} | ${printUnicodeString(DESIRED_NAME)}")
|
||||
|
||||
// Contains various attempted >< escapes, trailing NULL, and BACKSPACE + 'n'
|
||||
final List MALICIOUS_NAMES = [MALICIOUS_USER_NAME_JOHN,
|
||||
SAFE_USER_NAME_JOHN + ">",
|
||||
SAFE_USER_NAME_JOHN + "><>",
|
||||
SAFE_USER_NAME_JOHN + "\\>",
|
||||
SAFE_USER_NAME_JOHN + "\u003e",
|
||||
SAFE_USER_NAME_JOHN + "\u005c\u005c\u003e",
|
||||
SAFE_USER_NAME_JOHN + "\u0000",
|
||||
SAFE_USER_NAME_JOHN + "\u0008n"]
|
||||
|
||||
// Act
|
||||
MALICIOUS_NAMES.each { String name ->
|
||||
logger.info(" Raw name: ${name} | ${printUnicodeString(name)}")
|
||||
String sanitizedName = ProxiedEntitiesUtils.sanitizeDn(name)
|
||||
logger.info("Sanitized name: ${sanitizedName} | ${printUnicodeString(sanitizedName)}")
|
||||
|
||||
// Assert
|
||||
assertNotEquals(DESIRED_NAME, sanitizedName)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldFormatProxyDn() throws Exception {
|
||||
// Arrange
|
||||
final String DN = SAFE_USER_DN_JOHN
|
||||
logger.info(" Provided proxy DN: ${DN}")
|
||||
|
||||
final String EXPECTED_PROXY_DN = "<${DN}>"
|
||||
logger.info(" Expected proxy DN: ${EXPECTED_PROXY_DN}")
|
||||
|
||||
// Act
|
||||
String forjohnedProxyDn = ProxiedEntitiesUtils.formatProxyDn(DN)
|
||||
logger.info("Forjohned proxy DN: ${forjohnedProxyDn}")
|
||||
|
||||
// Assert
|
||||
assertEquals(EXPECTED_PROXY_DN, forjohnedProxyDn)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFormatProxyDnShouldHandleMaliciousInput() throws Exception {
|
||||
// Arrange
|
||||
final String DN = MALICIOUS_USER_DN_JOHN
|
||||
logger.info(" Provided proxy DN: ${DN}")
|
||||
|
||||
final String SANITIZED_DN = sanitizeDn(DN)
|
||||
final String EXPECTED_PROXY_DN = "<${SANITIZED_DN}>"
|
||||
logger.info(" Expected proxy DN: ${EXPECTED_PROXY_DN}")
|
||||
|
||||
// Act
|
||||
String forjohnedProxyDn = ProxiedEntitiesUtils.formatProxyDn(DN)
|
||||
logger.info("Forjohned proxy DN: ${forjohnedProxyDn}")
|
||||
|
||||
// Assert
|
||||
assertEquals(EXPECTED_PROXY_DN, forjohnedProxyDn)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetProxiedEntitiesChain() throws Exception {
|
||||
// Arrange
|
||||
String[] input = [SAFE_USER_NAME_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2]
|
||||
final String expectedOutput = "<${SAFE_USER_NAME_JOHN}><${SAFE_USER_DN_PROXY_1}><${SAFE_USER_DN_PROXY_2}>"
|
||||
|
||||
// Act
|
||||
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
|
||||
|
||||
// Assert
|
||||
assertEquals(expectedOutput, output)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetProxiedEntitiesChainShouldHandleMaliciousInput() throws Exception {
|
||||
// Arrange
|
||||
String[] input = [MALICIOUS_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2]
|
||||
final String expectedOutput = "<${sanitizeDn(MALICIOUS_USER_DN_JOHN)}><${SAFE_USER_DN_PROXY_1}><${SAFE_USER_DN_PROXY_2}>"
|
||||
|
||||
// Act
|
||||
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
|
||||
|
||||
// Assert
|
||||
assertEquals(expectedOutput, output)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetProxiedEntitiesChainShouldEncodeUnicode() throws Exception {
|
||||
// Arrange
|
||||
String[] input = [SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2]
|
||||
final String expectedOutput = "<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}><${UNICODE_DN_2_ENCODED}>"
|
||||
|
||||
// Act
|
||||
def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input)
|
||||
|
||||
// Assert
|
||||
assertEquals(expectedOutput, output)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFormatProxyDnShouldEncodeNonAsciiCharacters() throws Exception {
|
||||
// Arrange
|
||||
logger.info(" Provided DN: ${UNICODE_DN_1}")
|
||||
final String expectedFormattedDn = "<${UNICODE_DN_1_ENCODED}>"
|
||||
logger.info(" Expected DN: expected")
|
||||
|
||||
// Act
|
||||
String formattedDn = ProxiedEntitiesUtils.formatProxyDn(UNICODE_DN_1)
|
||||
logger.info("Formatted DN: ${formattedDn}")
|
||||
|
||||
// Assert
|
||||
assertEquals(expectedFormattedDn, formattedDn)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldBuildProxyChain() throws Exception {
|
||||
// Arrange
|
||||
def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser
|
||||
def mockJohn = [getIdentity: { -> SAFE_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser
|
||||
|
||||
// Act
|
||||
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn)
|
||||
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
|
||||
|
||||
// Assert
|
||||
assertEquals("<${SAFE_USER_NAME_JOHN}><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBuildProxyChainFromNullUserShouldBeAnonymous() throws Exception {
|
||||
// Arrange
|
||||
|
||||
// Act
|
||||
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(null)
|
||||
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
|
||||
|
||||
// Assert
|
||||
assertEquals("<>", proxiedEntitiesChain)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBuildProxyChainFromAnonymousUserShouldBeAnonymous() throws Exception {
|
||||
// Arrange
|
||||
def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser
|
||||
def mockAnonymous = [getIdentity: { -> "anonymous" }, getChain: { -> mockProxy1 }, isAnonymous: { -> true}] as NiFiUser
|
||||
|
||||
// Act
|
||||
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockAnonymous)
|
||||
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
|
||||
|
||||
// Assert
|
||||
assertEquals("<><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBuildProxyChainShouldHandleUnicode() throws Exception {
|
||||
// Arrange
|
||||
def mockProxy1 = [getIdentity: { -> UNICODE_DN_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser
|
||||
def mockJohn = [getIdentity: { -> SAFE_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser
|
||||
|
||||
// Act
|
||||
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn)
|
||||
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
|
||||
|
||||
// Assert
|
||||
assertEquals("<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}>" as String, proxiedEntitiesChain)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBuildProxyChainShouldHandleMaliciousUser() throws Exception {
|
||||
// Arrange
|
||||
def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser
|
||||
def mockJohn = [getIdentity: { -> MALICIOUS_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser
|
||||
|
||||
// Act
|
||||
String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn)
|
||||
logger.info("Proxied entities chain: ${proxiedEntitiesChain}")
|
||||
|
||||
// Assert
|
||||
assertEquals("<${MALICIOUS_USER_NAME_JOHN_ESCAPED}><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldTokenizeProxiedEntitiesChainWithUserNames() throws Exception {
|
||||
// Arrange
|
||||
final List NAMES = [SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2]
|
||||
final String RAW_PROXY_CHAIN = "<${NAMES.join("><")}>"
|
||||
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
|
||||
|
||||
// Act
|
||||
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
|
||||
logger.info("Tokenized proxy chain: ${tokenizedNames}")
|
||||
|
||||
// Assert
|
||||
assertEquals(NAMES, tokenizedNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldTokenizeAnonymous() throws Exception {
|
||||
// Arrange
|
||||
final List NAMES = [""]
|
||||
final String RAW_PROXY_CHAIN = "<>"
|
||||
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
|
||||
|
||||
// Act
|
||||
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
|
||||
logger.info("Tokenized proxy chain: ${tokenizedNames}")
|
||||
|
||||
// Assert
|
||||
assertEquals(NAMES, tokenizedNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldTokenizeDoubleAnonymous() throws Exception {
|
||||
// Arrange
|
||||
final List NAMES = ["", ""]
|
||||
final String RAW_PROXY_CHAIN = "<><>"
|
||||
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
|
||||
|
||||
// Act
|
||||
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
|
||||
logger.info("Tokenized proxy chain: ${tokenizedNames}")
|
||||
|
||||
// Assert
|
||||
assertEquals(NAMES, tokenizedNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldTokenizeNestedAnonymous() throws Exception {
|
||||
// Arrange
|
||||
final List NAMES = [SAFE_USER_DN_PROXY_1, "", SAFE_USER_DN_PROXY_2]
|
||||
final String RAW_PROXY_CHAIN = "<${SAFE_USER_DN_PROXY_1}><><${SAFE_USER_DN_PROXY_2}>"
|
||||
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
|
||||
|
||||
// Act
|
||||
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
|
||||
logger.info("Tokenized proxy chain: ${tokenizedNames}")
|
||||
|
||||
// Assert
|
||||
assertEquals(NAMES, tokenizedNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldTokenizeProxiedEntitiesChainWithDNs() throws Exception {
|
||||
// Arrange
|
||||
final List DNS = [SAFE_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2]
|
||||
final String RAW_PROXY_CHAIN = "<${DNS.join("><")}>"
|
||||
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
|
||||
|
||||
// Act
|
||||
def tokenizedDns = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
|
||||
logger.info("Tokenized proxy chain: ${tokenizedDns.collect { "\"${it}\"" }}")
|
||||
|
||||
// Assert
|
||||
assertEquals(DNS, tokenizedDns)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldTokenizeProxiedEntitiesChainWithAnonymousUser() throws Exception {
|
||||
// Arrange
|
||||
final List NAMES = ["", SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2]
|
||||
final String RAW_PROXY_CHAIN = "<${NAMES.join("><")}>"
|
||||
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
|
||||
|
||||
// Act
|
||||
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
|
||||
logger.info("Tokenized proxy chain: ${tokenizedNames}")
|
||||
|
||||
// Assert
|
||||
assertEquals(NAMES, tokenizedNames)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testTokenizeProxiedEntitiesChainShouldHandleMaliciousUser() throws Exception {
|
||||
// Arrange
|
||||
final List NAMES = [MALICIOUS_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2]
|
||||
final String RAW_PROXY_CHAIN = "<${NAMES.collect { sanitizeDn(it) }.join("><")}>"
|
||||
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
|
||||
|
||||
// Act
|
||||
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
|
||||
logger.info("Tokenized proxy chain: ${tokenizedNames.collect { "\"${it}\"" }}")
|
||||
|
||||
// Assert
|
||||
assertEquals(NAMES, tokenizedNames)
|
||||
assertEquals(NAMES.size(), tokenizedNames.size())
|
||||
assertFalse(tokenizedNames.contains(SAFE_USER_NAME_JOHN))
|
||||
}
|
||||
|
||||
@Test
|
||||
void testTokenizeProxiedEntitiesChainShouldDecodeNonAsciiValues() throws Exception {
|
||||
// Arrange
|
||||
final String RAW_PROXY_CHAIN = "<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}><${UNICODE_DN_2_ENCODED}>"
|
||||
final List TOKENIZED_NAMES = [SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2]
|
||||
logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}")
|
||||
|
||||
// Act
|
||||
def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN)
|
||||
logger.info("Tokenized proxy chain: ${tokenizedNames.collect { "\"${it}\"" }}")
|
||||
|
||||
// Assert
|
||||
assertEquals(TOKENIZED_NAMES, tokenizedNames)
|
||||
assertEquals(TOKENIZED_NAMES.size(), tokenizedNames.size())
|
||||
}
|
||||
}
|
|
@ -1,277 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.security.requests
|
||||
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
import org.apache.nifi.stream.io.StreamUtils
|
||||
import org.eclipse.jetty.server.LocalConnector
|
||||
import org.eclipse.jetty.server.Server
|
||||
import org.eclipse.jetty.servlet.FilterHolder
|
||||
import org.eclipse.jetty.servlet.ServletContextHandler
|
||||
import org.eclipse.jetty.servlet.ServletHolder
|
||||
import org.junit.jupiter.api.AfterEach
|
||||
import org.junit.jupiter.api.BeforeEach
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
|
||||
import javax.servlet.DispatcherType
|
||||
import javax.servlet.ServletException
|
||||
import javax.servlet.ServletInputStream
|
||||
import javax.servlet.http.HttpServlet
|
||||
import javax.servlet.http.HttpServletRequest
|
||||
import javax.servlet.http.HttpServletResponse
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue
|
||||
|
||||
class ContentLengthFilterTest {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ContentLengthFilterTest.class)
|
||||
|
||||
private static final int MAX_CONTENT_LENGTH = 1000
|
||||
private static final int SERVER_IDLE_TIMEOUT = 2500 // only one request needed + value large enough for slow systems
|
||||
private static final String POST_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s"
|
||||
private static final String FORM_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\nContent-Type: application/x-www-form-urlencoded\r\nAccept-Charset: UTF-8\r\n\r\n%s"
|
||||
public static final int FORM_CONTENT_SIZE = 128
|
||||
|
||||
// These variables hold data for content small enough to be allowed
|
||||
private static final int SMALL_CLAIM_SIZE_BYTES = 150
|
||||
private static final String SMALL_PAYLOAD = "1" * SMALL_CLAIM_SIZE_BYTES
|
||||
|
||||
// These variables hold data for content too large to be allowed
|
||||
private static final int LARGE_CLAIM_SIZE_BYTES = 2000
|
||||
private static final String LARGE_PAYLOAD = "1" * LARGE_CLAIM_SIZE_BYTES
|
||||
|
||||
private Server serverUnderTest
|
||||
private LocalConnector localConnector
|
||||
private ServletContextHandler contextUnderTest
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
createSimpleReadServer()
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
stopServer()
|
||||
}
|
||||
|
||||
void stopServer() throws Exception {
|
||||
if (serverUnderTest && serverUnderTest.isRunning()) {
|
||||
serverUnderTest.stop()
|
||||
}
|
||||
}
|
||||
|
||||
private void configureAndStartServer(HttpServlet servlet, int maxFormContentSize) throws Exception {
|
||||
serverUnderTest = new Server()
|
||||
localConnector = new LocalConnector(serverUnderTest)
|
||||
localConnector.setIdleTimeout(SERVER_IDLE_TIMEOUT)
|
||||
serverUnderTest.addConnector(localConnector)
|
||||
|
||||
contextUnderTest = new ServletContextHandler(serverUnderTest, "/")
|
||||
if (maxFormContentSize > 0) {
|
||||
contextUnderTest.setMaxFormContentSize(maxFormContentSize)
|
||||
}
|
||||
contextUnderTest.addServlet(new ServletHolder(servlet), "/*")
|
||||
|
||||
// This only adds the ContentLengthFilter if a valid maxFormContentSize is not provided
|
||||
if (maxFormContentSize < 0) {
|
||||
FilterHolder holder = contextUnderTest.addFilter(ContentLengthFilter.class, "/*", EnumSet.of(DispatcherType.REQUEST) as EnumSet<DispatcherType>)
|
||||
holder.setInitParameter(ContentLengthFilter.MAX_LENGTH_INIT_PARAM, String.valueOf(MAX_CONTENT_LENGTH))
|
||||
}
|
||||
serverUnderTest.start()
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a server which consumes any provided request input stream and returns HTTP 200. It has no
|
||||
* {@code maxFormContentSize}, so the {@link ContentLengthFilter} is applied. The response contains a header and the
|
||||
* response body indicating the total number of request content bytes read.
|
||||
*
|
||||
* @throws Exception if there is a problem setting up the server
|
||||
*/
|
||||
private void createSimpleReadServer() throws Exception {
|
||||
HttpServlet mockServlet = [
|
||||
doPost: { HttpServletRequest req, HttpServletResponse resp ->
|
||||
byte[] byteBuffer = new byte[2048]
|
||||
int bytesRead = StreamUtils.fillBuffer(req.getInputStream(), byteBuffer, false)
|
||||
resp.setHeader("Bytes-Read", bytesRead as String)
|
||||
resp.setStatus(HttpServletResponse.SC_OK)
|
||||
resp.getWriter().write("Read ${bytesRead} bytes of request input")
|
||||
}
|
||||
] as HttpServlet
|
||||
configureAndStartServer(mockServlet, -1)
|
||||
}
|
||||
|
||||
private static void logResponse(String response, String s = "Response: ") {
|
||||
String responseId = String.valueOf(System.currentTimeMillis() % 100)
|
||||
final String delimiterLine = "\n-----" + responseId + "-----\n"
|
||||
String formattedResponse = s + delimiterLine + response + delimiterLine
|
||||
logger.info(formattedResponse)
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRequestsWithMissingContentLengthHeader() throws Exception {
|
||||
// This shows that the ContentLengthFilter allows a request that does not have a content-length header.
|
||||
String response = localConnector.getResponse("POST / HTTP/1.0\r\n\r\n")
|
||||
assertFalse(StringUtils.containsIgnoreCase(response, "411 Length Required"))
|
||||
}
|
||||
|
||||
/**
|
||||
* This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends more than
|
||||
* the max.
|
||||
*/
|
||||
@Test
|
||||
void testShouldRejectRequestWithLongContentLengthHeader() throws Exception {
|
||||
// Arrange
|
||||
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
|
||||
logger.info("Making request with CL: ${LARGE_CLAIM_SIZE_BYTES} and actual length: ${LARGE_PAYLOAD.length()}")
|
||||
|
||||
// Act
|
||||
String response = localConnector.getResponse(requestBody)
|
||||
logResponse(response)
|
||||
|
||||
// Assert
|
||||
assertTrue(response.contains("413 Payload Too Large"))
|
||||
}
|
||||
|
||||
/**
|
||||
* This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends less than
|
||||
* the claim.
|
||||
*/
|
||||
@Test
|
||||
void testShouldRejectRequestWithLongContentLengthHeaderAndSmallPayload() throws Exception {
|
||||
// Arrange
|
||||
String incompletePayload = "1" * (SMALL_CLAIM_SIZE_BYTES / 2)
|
||||
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, incompletePayload)
|
||||
logger.info("Making request with CL: ${LARGE_CLAIM_SIZE_BYTES} and actual length: ${incompletePayload.length()}")
|
||||
|
||||
// Act
|
||||
String response = localConnector.getResponse(requestBody)
|
||||
logResponse(response)
|
||||
|
||||
// Assert
|
||||
assertTrue(response.contains("413 Payload Too Large"))
|
||||
}
|
||||
|
||||
/**
|
||||
* This shows that the ContentLengthFilter <em>allows</em> a request when the client claims less
|
||||
* than the max + sends more than the max, but restricts the request body to the stated content
|
||||
* length size.
|
||||
*/
|
||||
@Test
|
||||
void testShouldRejectRequestWithSmallContentLengthHeaderAndLargePayload() throws Exception {
|
||||
// Arrange
|
||||
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
|
||||
logger.info("Making request with CL: ${SMALL_CLAIM_SIZE_BYTES} and actual length: ${LARGE_PAYLOAD.length()}")
|
||||
|
||||
// Act
|
||||
String response = localConnector.getResponse(requestBody)
|
||||
logResponse(response)
|
||||
|
||||
// Assert
|
||||
assertTrue(response.contains("200"))
|
||||
assertTrue(response.contains("Bytes-Read: " + SMALL_CLAIM_SIZE_BYTES))
|
||||
assertTrue(response.contains("Read " + SMALL_CLAIM_SIZE_BYTES + " bytes"))
|
||||
}
|
||||
|
||||
/**
|
||||
* This shows that the server times out when the client claims less than the max + sends less than the max + sends
|
||||
* less than it claims to send.
|
||||
*/
|
||||
@Test
|
||||
void testShouldTimeoutRequestWithSmallContentLengthHeaderAndSmallerPayload() throws Exception {
|
||||
// Arrange
|
||||
String smallerPayload = SMALL_PAYLOAD[0..(SMALL_PAYLOAD.length() / 2)]
|
||||
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, smallerPayload)
|
||||
logger.info("Making request with CL: ${SMALL_CLAIM_SIZE_BYTES} and actual length: ${smallerPayload.length()}")
|
||||
|
||||
// Act
|
||||
String response = localConnector.getResponse(requestBody, 500, TimeUnit.MILLISECONDS)
|
||||
logResponse(response)
|
||||
|
||||
// Assert
|
||||
assertTrue(response.contains("500 Server Error"))
|
||||
assertTrue(response.contains("Timeout"))
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFilterShouldAllowSiteToSiteTransfer() throws Exception {
|
||||
// Arrange
|
||||
final String SITE_TO_SITE_POST_REQUEST = "POST /nifi-api/data-transfer/input-ports HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s"
|
||||
|
||||
final String siteToSiteRequest = String.format(SITE_TO_SITE_POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD)
|
||||
logResponse(siteToSiteRequest, "Request: ")
|
||||
|
||||
// Act
|
||||
String response = localConnector.getResponse(siteToSiteRequest)
|
||||
logResponse(response)
|
||||
|
||||
// Assert
|
||||
assertTrue(response.contains("200 OK"))
|
||||
}
|
||||
|
||||
@Test
|
||||
void testJettyMaxFormSize() throws Exception {
|
||||
// This shows that the jetty server option for 'maxFormContentSize' is insufficient for our needs because it
|
||||
// catches requests like this:
|
||||
|
||||
// Configure the server but do not apply the CLF because the FORM_CONTENT_SIZE > 0
|
||||
configureAndStartServer(new HttpServlet() {
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
|
||||
try {
|
||||
req.getParameterMap()
|
||||
ServletInputStream input = req.getInputStream()
|
||||
int count = 0
|
||||
while (!input.isFinished()) {
|
||||
input.read()
|
||||
count += 1
|
||||
}
|
||||
final int FORM_LIMIT_BYTES = FORM_CONTENT_SIZE + "a=\n".length()
|
||||
if (count > FORM_LIMIT_BYTES) {
|
||||
logger.warn("Bytes read ({}) is larger than the limit ({})", count, FORM_LIMIT_BYTES)
|
||||
resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Should not reach this code.")
|
||||
} else {
|
||||
logger.warn("Bytes read ({}) is less than or equal to the limit ({})", count, FORM_LIMIT_BYTES)
|
||||
resp.sendError(HttpServletResponse.SC_EXPECTATION_FAILED, "Read Too Many Bytes")
|
||||
}
|
||||
} catch (final Exception e) {
|
||||
// This is the jetty context returning a 400 from the maxFormContentSize setting:
|
||||
if (StringUtils.containsIgnoreCase(e.getCause().toString(), "Form is larger than max length " + FORM_CONTENT_SIZE)) {
|
||||
logger.warn("Exception thrown by input stream: ", e)
|
||||
resp.sendError(HttpServletResponse.SC_REQUEST_ENTITY_TOO_LARGE, "Payload Too Large")
|
||||
} else {
|
||||
logger.warn("Exception thrown by input stream: ", e)
|
||||
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Should not reach this code, either.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}, FORM_CONTENT_SIZE)
|
||||
|
||||
// Test to catch a form submission that exceeds the FORM_CONTENT_SIZE limit
|
||||
String form = "a=" + "1" * FORM_CONTENT_SIZE
|
||||
String response = localConnector.getResponse(String.format(FORM_REQUEST, form.length(), form))
|
||||
logResponse(response)
|
||||
assertTrue(response.contains("413 Payload Too Large"))
|
||||
|
||||
|
||||
// But it does not catch requests like this:
|
||||
response = localConnector.getResponse(String.format(POST_REQUEST, form.length(), form + form))
|
||||
assertTrue(response.contains("417 Read Too Many Bytes"))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,242 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.security;
|
||||
|
||||
import org.apache.nifi.authorization.user.NiFiUser;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class ProxiedEntitiesUtilsTest {
|
||||
private static final String SAFE_USER_NAME_JOHN = "jdoe";
|
||||
private static final String SAFE_USER_DN_JOHN = "CN=" + SAFE_USER_NAME_JOHN + ", OU=Apache NiFi";
|
||||
private static final String SAFE_USER_NAME_PROXY_1 = "proxy1.nifi.apache.org";
|
||||
private static final String SAFE_USER_DN_PROXY_1 = "CN=" + SAFE_USER_NAME_PROXY_1 + ", OU=Apache NiFi";
|
||||
private static final String SAFE_USER_NAME_PROXY_2 = "proxy2.nifi.apache.org";
|
||||
private static final String SAFE_USER_DN_PROXY_2 = "CN=" + SAFE_USER_NAME_PROXY_2 + ", OU=Apache NiFi";
|
||||
private static final String MALICIOUS_USER_NAME_JOHN = SAFE_USER_NAME_JOHN + ", OU=Apache NiFi><CN=" + SAFE_USER_NAME_PROXY_1;
|
||||
private static final String MALICIOUS_USER_DN_JOHN = "CN=" + MALICIOUS_USER_NAME_JOHN + ", OU=Apache NiFi";
|
||||
private static final String MALICIOUS_USER_NAME_JOHN_ESCAPED = sanitizeDn(MALICIOUS_USER_NAME_JOHN);
|
||||
private static final String UNICODE_DN_1 = "CN=Алйс, OU=Apache NiFi";
|
||||
private static final String UNICODE_DN_1_ENCODED = "<" + base64Encode(UNICODE_DN_1) + ">";
|
||||
private static final String UNICODE_DN_2 = "CN=Боб, OU=Apache NiFi";
|
||||
private static final String UNICODE_DN_2_ENCODED = "<" + base64Encode(UNICODE_DN_2) + ">";
|
||||
private static final String ANONYMOUS_USER = "";
|
||||
private static final String ANONYMOUS_PROXIED_ENTITY_CHAIN = "<>";
|
||||
|
||||
private static String sanitizeDn(String dn) {
|
||||
return dn.replaceAll(">", "\\\\>").replaceAll("<", "\\\\<");
|
||||
}
|
||||
|
||||
private static String base64Encode(String dn) {
|
||||
return Base64.getEncoder().encodeToString(dn.getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("getMaliciousNames" )
|
||||
public void testSanitizeDnShouldHandleFuzzing(String maliciousName) {
|
||||
assertNotEquals(formatDn(SAFE_USER_NAME_JOHN), ProxiedEntitiesUtils.formatProxyDn(maliciousName));
|
||||
}
|
||||
|
||||
// Contains various attempted >< escapes, trailing NULL, and BACKSPACE + 'n'
|
||||
private static List<String> getMaliciousNames() {
|
||||
return Arrays.asList(MALICIOUS_USER_NAME_JOHN,
|
||||
SAFE_USER_NAME_JOHN + ">",
|
||||
SAFE_USER_NAME_JOHN + "><>",
|
||||
SAFE_USER_NAME_JOHN + "\\>",
|
||||
SAFE_USER_NAME_JOHN + "\u003e",
|
||||
SAFE_USER_NAME_JOHN + "\u005c\u005c\u003e",
|
||||
SAFE_USER_NAME_JOHN + "\u0000",
|
||||
SAFE_USER_NAME_JOHN + "\u0008n");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldFormatProxyDn() {
|
||||
assertEquals(formatDn(SAFE_USER_DN_JOHN), ProxiedEntitiesUtils.formatProxyDn(SAFE_USER_DN_JOHN));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFormatProxyDnShouldHandleMaliciousInput() {
|
||||
assertEquals(formatSanitizedDn(MALICIOUS_USER_DN_JOHN), ProxiedEntitiesUtils.formatProxyDn(MALICIOUS_USER_DN_JOHN));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetProxiedEntitiesChain() {
|
||||
String[] input = new String [] {SAFE_USER_NAME_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2};
|
||||
assertEquals(formatDns(input), ProxiedEntitiesUtils.getProxiedEntitiesChain(input));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetProxiedEntitiesChainShouldHandleMaliciousInput() {
|
||||
final String expectedOutput = formatSanitizedDn(MALICIOUS_USER_DN_JOHN) + formatDns(SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2);
|
||||
assertEquals(expectedOutput, ProxiedEntitiesUtils.getProxiedEntitiesChain(MALICIOUS_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetProxiedEntitiesChainShouldEncodeUnicode() {
|
||||
assertEquals(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED, UNICODE_DN_2_ENCODED),
|
||||
ProxiedEntitiesUtils.getProxiedEntitiesChain(SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFormatProxyDnShouldEncodeNonAsciiCharacters() {
|
||||
assertEquals(formatDn(UNICODE_DN_1_ENCODED), ProxiedEntitiesUtils.formatProxyDn(UNICODE_DN_1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldBuildProxyChain(@Mock NiFiUser proxy1, @Mock NiFiUser john) {
|
||||
when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1);
|
||||
when(proxy1.getChain()).thenReturn(null);
|
||||
when(proxy1.isAnonymous()).thenReturn(false);
|
||||
when(john.getIdentity()).thenReturn(SAFE_USER_NAME_JOHN);
|
||||
when(john.getChain()).thenReturn(proxy1);
|
||||
when(john.isAnonymous()).thenReturn(false);
|
||||
|
||||
assertEquals(formatDns(SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBuildProxyChainFromNullUserShouldBeAnonymous() {
|
||||
assertEquals(ANONYMOUS_PROXIED_ENTITY_CHAIN, ProxiedEntitiesUtils.buildProxiedEntitiesChainString(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBuildProxyChainFromAnonymousUserShouldBeAnonymous(@Mock NiFiUser proxy1, @Mock NiFiUser anonymous) {
|
||||
when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1);
|
||||
when(proxy1.getChain()).thenReturn(null);
|
||||
when(proxy1.isAnonymous()).thenReturn(false);
|
||||
when(anonymous.getChain()).thenReturn(proxy1);
|
||||
when(anonymous.isAnonymous()).thenReturn(true);
|
||||
|
||||
assertEquals(formatDns(ANONYMOUS_USER, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(anonymous));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBuildProxyChainShouldHandleUnicode(@Mock NiFiUser proxy1, @Mock NiFiUser john) {
|
||||
when(proxy1.getIdentity()).thenReturn(UNICODE_DN_1);
|
||||
when(proxy1.getChain()).thenReturn(null);
|
||||
when(proxy1.isAnonymous()).thenReturn(false);
|
||||
when(john.getIdentity()).thenReturn(SAFE_USER_NAME_JOHN);
|
||||
when(john.getChain()).thenReturn(proxy1);
|
||||
when(john.isAnonymous()).thenReturn(false);
|
||||
|
||||
assertEquals(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBuildProxyChainShouldHandleMaliciousUser(@Mock NiFiUser proxy1, @Mock NiFiUser john) {
|
||||
when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1);
|
||||
when(proxy1.getChain()).thenReturn(null);
|
||||
when(proxy1.isAnonymous()).thenReturn(false);
|
||||
when(john.getIdentity()).thenReturn(MALICIOUS_USER_NAME_JOHN);
|
||||
when(john.getChain()).thenReturn(proxy1);
|
||||
when(john.isAnonymous()).thenReturn(false);
|
||||
|
||||
assertEquals(formatDns(MALICIOUS_USER_NAME_JOHN_ESCAPED, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldTokenizeProxiedEntitiesChainWithUserNames() {
|
||||
final List<String> names = Arrays.asList(SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2);
|
||||
final String rawProxyChain = formatDns(names.toArray(new String[0]));
|
||||
|
||||
assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldTokenizeAnonymous() {
|
||||
assertEquals(Collections.singletonList(ANONYMOUS_USER), ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(ANONYMOUS_PROXIED_ENTITY_CHAIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldTokenizeDoubleAnonymous() {
|
||||
assertEquals(Arrays.asList(ANONYMOUS_USER, ANONYMOUS_USER), ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(ANONYMOUS_PROXIED_ENTITY_CHAIN.repeat(2)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldTokenizeNestedAnonymous() {
|
||||
final List<String> names = Arrays.asList(SAFE_USER_DN_PROXY_1, ANONYMOUS_USER, SAFE_USER_DN_PROXY_2);
|
||||
final String rawProxyChain = formatDns(names.toArray(new String [0]));
|
||||
|
||||
assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldTokenizeProxiedEntitiesChainWithDNs() {
|
||||
final List<String> dns = Arrays.asList(SAFE_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2);
|
||||
final String rawProxyChain = formatDns(dns.toArray(new String[0]));
|
||||
|
||||
assertEquals(dns, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldTokenizeProxiedEntitiesChainWithAnonymousUser() {
|
||||
final List<String> names = Arrays.asList(ANONYMOUS_USER, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2);
|
||||
final String rawProxyChain = formatDns(names.toArray(new String[0]));
|
||||
|
||||
assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTokenizeProxiedEntitiesChainShouldHandleMaliciousUser() {
|
||||
final List<String> names = Arrays.asList(MALICIOUS_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2);
|
||||
final String rawProxyChain = names.stream()
|
||||
.map(this::formatSanitizedDn)
|
||||
.collect(Collectors.joining());
|
||||
List<String> tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain);
|
||||
|
||||
assertEquals(names, tokenizedNames);
|
||||
assertFalse(tokenizedNames.contains(SAFE_USER_NAME_JOHN));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTokenizeProxiedEntitiesChainShouldDecodeNonAsciiValues() {
|
||||
List<String> tokenizedNames =
|
||||
ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED, UNICODE_DN_2_ENCODED));
|
||||
|
||||
assertEquals(Arrays.asList(SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2), tokenizedNames);
|
||||
}
|
||||
|
||||
private String formatSanitizedDn(String dn) {
|
||||
return formatDn((sanitizeDn(dn)));
|
||||
}
|
||||
|
||||
private String formatDn(String dn) {
|
||||
return formatDns(dn);
|
||||
}
|
||||
|
||||
private String formatDns(String...dns) {
|
||||
return Arrays.stream(dns)
|
||||
.collect(Collectors.joining("><", "<", ">"));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,232 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.security.requests;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.nifi.stream.io.StreamUtils;
|
||||
import org.eclipse.jetty.server.LocalConnector;
|
||||
import org.eclipse.jetty.server.Server;
|
||||
import org.eclipse.jetty.servlet.FilterHolder;
|
||||
import org.eclipse.jetty.servlet.ServletContextHandler;
|
||||
import org.eclipse.jetty.servlet.ServletHolder;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import javax.servlet.DispatcherType;
|
||||
import javax.servlet.ServletInputStream;
|
||||
import javax.servlet.http.HttpServlet;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class ContentLengthFilterTest {
|
||||
private static final String POST_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s";
|
||||
public static final int FORM_CONTENT_SIZE = 128;
|
||||
|
||||
// These variables hold data for content small enough to be allowed
|
||||
private static final int SMALL_CLAIM_SIZE_BYTES = 150;
|
||||
private static final String SMALL_PAYLOAD = "1".repeat(SMALL_CLAIM_SIZE_BYTES);
|
||||
|
||||
// These variables hold data for content too large to be allowed
|
||||
private static final int LARGE_CLAIM_SIZE_BYTES = 2000;
|
||||
private static final String LARGE_PAYLOAD = "1".repeat(LARGE_CLAIM_SIZE_BYTES);
|
||||
|
||||
private Server serverUnderTest;
|
||||
private LocalConnector localConnector;
|
||||
|
||||
@BeforeEach
|
||||
public void setUp() throws Exception {
|
||||
createSimpleReadServer();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void tearDown() throws Exception {
|
||||
stopServer();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRequestsWithMissingContentLengthHeader() throws Exception {
|
||||
// This shows that the ContentLengthFilter allows a request that does not have a content-length header.
|
||||
String response = localConnector.getResponse("POST / HTTP/1.0\r\n\r\n");
|
||||
assertFalse(StringUtils.containsIgnoreCase(response, "411 Length Required"));
|
||||
}
|
||||
|
||||
/**
|
||||
* This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends more than
|
||||
* the max.
|
||||
*/
|
||||
@Test
|
||||
public void testShouldRejectRequestWithLongContentLengthHeader() throws Exception {
|
||||
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD);
|
||||
String response = localConnector.getResponse(requestBody);
|
||||
|
||||
assertTrue(response.contains("413 Payload Too Large"));
|
||||
}
|
||||
|
||||
/**
|
||||
* This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends less than
|
||||
* the claim.
|
||||
*/
|
||||
@Test
|
||||
public void testShouldRejectRequestWithLongContentLengthHeaderAndSmallPayload() throws Exception {
|
||||
String incompletePayload = "1".repeat(SMALL_CLAIM_SIZE_BYTES / 2);
|
||||
final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, incompletePayload);
|
||||
String response = localConnector.getResponse(requestBody);
|
||||
|
||||
assertTrue(response.contains("413 Payload Too Large"));
|
||||
}
|
||||
|
||||
/**
|
||||
* This shows that the ContentLengthFilter <em>allows</em> a request when the client claims less
|
||||
* than the max + sends more than the max, but restricts the request body to the stated content
|
||||
* length size.
|
||||
*/
|
||||
@Test
|
||||
public void testShouldRejectRequestWithSmallContentLengthHeaderAndLargePayload() throws Exception {
|
||||
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, LARGE_PAYLOAD);
|
||||
String response = localConnector.getResponse(requestBody);
|
||||
|
||||
assertTrue(response.contains("200"));
|
||||
assertTrue(response.contains("Bytes-Read: " + SMALL_CLAIM_SIZE_BYTES));
|
||||
assertTrue(response.contains("Read " + SMALL_CLAIM_SIZE_BYTES + " bytes"));
|
||||
}
|
||||
|
||||
/**
|
||||
* This shows that the server times out when the client claims less than the max + sends less than the max + sends
|
||||
* less than it claims to send.
|
||||
*/
|
||||
@Test
|
||||
public void testShouldTimeoutRequestWithSmallContentLengthHeaderAndSmallerPayload() throws Exception {
|
||||
String smallerPayload = SMALL_PAYLOAD.substring(0, SMALL_PAYLOAD.length() / 2);
|
||||
final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, smallerPayload);
|
||||
String response = localConnector.getResponse(requestBody, 500, TimeUnit.MILLISECONDS);
|
||||
|
||||
assertTrue(response.contains("500 Server Error"));
|
||||
assertTrue(response.contains("Timeout"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFilterShouldAllowSiteToSiteTransfer() throws Exception {
|
||||
final String siteToSitePostRequest = "POST /nifi-api/data-transfer/input-ports HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s";
|
||||
final String siteToSiteRequest = String.format(siteToSitePostRequest, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD);
|
||||
String response = localConnector.getResponse(siteToSiteRequest);
|
||||
|
||||
assertTrue(response.contains("200 OK"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testJettyMaxFormSize() throws Exception {
|
||||
// This shows that the jetty server option for 'maxFormContentSize' is insufficient for our needs because it
|
||||
// catches requests like this:
|
||||
|
||||
// Configure the server but do not apply the CLF because the FORM_CONTENT_SIZE > 0
|
||||
configureAndStartServer(new HttpServlet() {
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
|
||||
try {
|
||||
req.getParameterMap();
|
||||
ServletInputStream input = req.getInputStream();
|
||||
int count = 0;
|
||||
while (!input.isFinished()) {
|
||||
input.read();
|
||||
count += 1;
|
||||
}
|
||||
final int formLimitBytes = FORM_CONTENT_SIZE + "a=\n".length();
|
||||
if (count > formLimitBytes) {
|
||||
resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Should not reach this code.");
|
||||
} else {
|
||||
resp.sendError(HttpServletResponse.SC_EXPECTATION_FAILED, "Read Too Many Bytes");
|
||||
}
|
||||
} catch (final Exception e) {
|
||||
// This is the jetty context returning a 400 from the maxFormContentSize setting:
|
||||
if (StringUtils.containsIgnoreCase(e.getCause().toString(), "Form is larger than max length " + FORM_CONTENT_SIZE)) {
|
||||
resp.sendError(HttpServletResponse.SC_REQUEST_ENTITY_TOO_LARGE, "Payload Too Large");
|
||||
} else {
|
||||
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Should not reach this code, either.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}, FORM_CONTENT_SIZE);
|
||||
|
||||
// Test to catch a form submission that exceeds the FORM_CONTENT_SIZE limit
|
||||
String form = "a=" + "1".repeat(FORM_CONTENT_SIZE);
|
||||
final String formRequest = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\nContent-Type: application/x-www-form-urlencoded\r\nAccept-Charset: UTF-8\r\n\r\n%s";
|
||||
String response = localConnector.getResponse(String.format(formRequest, form.length(), form));
|
||||
|
||||
assertTrue(response.contains("413 Payload Too Large"));
|
||||
|
||||
// But it does not catch requests like this:
|
||||
response = localConnector.getResponse(String.format(POST_REQUEST, form.length(), form + form));
|
||||
assertTrue(response.contains("417 Read Too Many Bytes"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a server which consumes any provided request input stream and returns HTTP 200. It has no
|
||||
* {@code maxFormContentSize}, so the {@link ContentLengthFilter} is applied. The response contains a header and the
|
||||
* response body indicating the total number of request content bytes read.
|
||||
*
|
||||
* @throws Exception if there is a problem setting up the server
|
||||
*/
|
||||
private void createSimpleReadServer() throws Exception {
|
||||
HttpServlet mockServlet = new HttpServlet() {
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
|
||||
byte[] byteBuffer = new byte[2048];
|
||||
int bytesRead = StreamUtils.fillBuffer(req.getInputStream(), byteBuffer, false);
|
||||
resp.setHeader("Bytes-Read", Integer.toString(bytesRead));
|
||||
resp.setStatus(HttpServletResponse.SC_OK);
|
||||
resp.getWriter().write("Read " + bytesRead + " bytes of request input");
|
||||
}
|
||||
};
|
||||
configureAndStartServer(mockServlet, -1);
|
||||
}
|
||||
|
||||
private void configureAndStartServer(HttpServlet servlet, int maxFormContentSize) throws Exception {
|
||||
serverUnderTest = new Server();
|
||||
localConnector = new LocalConnector(serverUnderTest);
|
||||
localConnector.setIdleTimeout(2500); // only one request needed + value large enough for slow systems
|
||||
serverUnderTest.addConnector(localConnector);
|
||||
|
||||
ServletContextHandler contextUnderTest = new ServletContextHandler(serverUnderTest, "/");
|
||||
if (maxFormContentSize > 0) {
|
||||
contextUnderTest.setMaxFormContentSize(maxFormContentSize);
|
||||
}
|
||||
contextUnderTest.addServlet(new ServletHolder(servlet), "/*");
|
||||
|
||||
// This only adds the ContentLengthFilter if a valid maxFormContentSize is not provided
|
||||
if (maxFormContentSize < 0) {
|
||||
FilterHolder holder = contextUnderTest.addFilter(ContentLengthFilter.class, "/*", EnumSet.of(DispatcherType.REQUEST));
|
||||
holder.setInitParameter(ContentLengthFilter.MAX_LENGTH_INIT_PARAM, String.valueOf(1000));
|
||||
}
|
||||
serverUnderTest.start();
|
||||
}
|
||||
|
||||
void stopServer() throws Exception {
|
||||
if (serverUnderTest != null && serverUnderTest.isRunning()) {
|
||||
serverUnderTest.stop();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue