mirror of https://github.com/apache/nifi.git
NIFI-5628 Added content length check to OkHttpReplicationClient.
Added unit tests. This closes #3035
This commit is contained in:
parent
0dd382370b
commit
748cf74562
|
@ -21,6 +21,35 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
|||
import com.fasterxml.jackson.annotation.JsonInclude.Value;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.module.jaxb.JaxbAnnotationIntrospector;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.net.URI;
|
||||
import java.security.KeyStore;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import java.util.zip.GZIPInputStream;
|
||||
import javax.net.ssl.KeyManager;
|
||||
import javax.net.ssl.KeyManagerFactory;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import javax.ws.rs.HttpMethod;
|
||||
import javax.ws.rs.core.MultivaluedHashMap;
|
||||
import javax.ws.rs.core.MultivaluedMap;
|
||||
import javax.ws.rs.core.Response;
|
||||
import okhttp3.Call;
|
||||
import okhttp3.ConnectionPool;
|
||||
import okhttp3.Headers;
|
||||
|
@ -42,36 +71,6 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.StreamUtils;
|
||||
|
||||
import javax.net.ssl.KeyManager;
|
||||
import javax.net.ssl.KeyManagerFactory;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import javax.ws.rs.HttpMethod;
|
||||
import javax.ws.rs.core.MultivaluedHashMap;
|
||||
import javax.ws.rs.core.MultivaluedMap;
|
||||
import javax.ws.rs.core.Response;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.net.URI;
|
||||
import java.security.KeyStore;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import java.util.zip.GZIPInputStream;
|
||||
|
||||
public class OkHttpReplicationClient implements HttpReplicationClient {
|
||||
private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClient.class);
|
||||
private static final Set<String> gzipEncodings = Stream.of("gzip", "x-gzip").collect(Collectors.toSet());
|
||||
|
@ -95,12 +94,35 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||
@Override
|
||||
public PreparedRequest prepareRequest(final String method, final Map<String, String> headers, final Object entity) {
|
||||
final boolean gzip = isUseGzip(headers);
|
||||
checkContentLengthHeader(method, headers);
|
||||
final RequestBody requestBody = createRequestBody(headers, entity, gzip);
|
||||
|
||||
final Map<String, String> updatedHeaders = gzip ? updateHeadersForGzip(headers) : headers;
|
||||
return new OkHttpPreparedRequest(method, updatedHeaders, entity, requestBody);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks the content length header on DELETE requests to ensure it is set to '0', avoiding request timeouts on replicated requests.
|
||||
* @param method the HTTP method of the request
|
||||
* @param headers the header keys and values
|
||||
*/
|
||||
private void checkContentLengthHeader(String method, Map<String, String> headers) {
|
||||
// Only applies to DELETE requests
|
||||
if (HttpMethod.DELETE.equalsIgnoreCase(method)) {
|
||||
// Find the Content-Length header if present
|
||||
final String CONTENT_LENGTH_HEADER_KEY = "Content-Length";
|
||||
Map.Entry<String, String> contentLengthEntry = headers.entrySet().stream().filter(entry -> entry.getKey().equalsIgnoreCase(CONTENT_LENGTH_HEADER_KEY)).findFirst().orElse(null);
|
||||
// If no CL header, do nothing
|
||||
if (contentLengthEntry != null) {
|
||||
// If the provided CL value is non-zero, override it
|
||||
if (contentLengthEntry.getValue() != null && !contentLengthEntry.getValue().equalsIgnoreCase("0")) {
|
||||
logger.warn("This is a DELETE request; the provided Content-Length was {}; setting Content-Length to 0", contentLengthEntry.getValue());
|
||||
headers.put(CONTENT_LENGTH_HEADER_KEY, "0");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response replicate(final PreparedRequest request, final String uri) throws IOException {
|
||||
if (!(Objects.requireNonNull(request) instanceof OkHttpPreparedRequest)) {
|
||||
|
@ -140,7 +162,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||
final String contentEncoding = callResponse.header("Content-Encoding");
|
||||
if (gzipEncodings.contains(contentEncoding)) {
|
||||
try (final InputStream gzipIn = new GZIPInputStream(new ByteArrayInputStream(rawBytes));
|
||||
final ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
|
||||
final ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
|
||||
|
||||
StreamUtils.copy(gzipIn, baos);
|
||||
return baos.toByteArray();
|
||||
|
@ -183,7 +205,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||
|
||||
@SuppressWarnings("unchecked")
|
||||
private HttpUrl buildUrl(final OkHttpPreparedRequest request, final String uri) {
|
||||
HttpUrl.Builder urlBuilder = HttpUrl.parse(uri.toString()).newBuilder();
|
||||
HttpUrl.Builder urlBuilder = HttpUrl.parse(uri).newBuilder();
|
||||
switch (request.getMethod().toUpperCase()) {
|
||||
case HttpMethod.DELETE:
|
||||
case HttpMethod.HEAD:
|
||||
|
@ -226,7 +248,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||
|
||||
private byte[] serializeEntity(final Object entity, final String contentType, final boolean gzip) {
|
||||
try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
final OutputStream out = gzip ? new GZIPOutputStream(baos, 1) : baos) {
|
||||
final OutputStream out = gzip ? new GZIPOutputStream(baos, 1) : baos) {
|
||||
|
||||
getSerializer(contentType).serialize(entity, out);
|
||||
out.close();
|
||||
|
@ -269,10 +291,10 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||
} else {
|
||||
final String[] acceptEncodingTokens = rawAcceptEncoding.split(",");
|
||||
return Stream.of(acceptEncodingTokens)
|
||||
.map(String::trim)
|
||||
.filter(StringUtils::isNotEmpty)
|
||||
.map(String::toLowerCase)
|
||||
.anyMatch(gzipEncodings::contains);
|
||||
.map(String::trim)
|
||||
.filter(StringUtils::isNotEmpty)
|
||||
.map(String::toLowerCase)
|
||||
.anyMatch(gzipEncodings::contains);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
package org.apache.nifi.cluster.coordination.http.replication.okhttp
|
||||
|
||||
import org.apache.nifi.properties.StandardNiFiProperties
|
||||
import org.apache.nifi.util.NiFiProperties
|
||||
import org.junit.BeforeClass
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.JUnit4
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
class OkHttpReplicationClientTest extends GroovyTestCase {
|
||||
private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClientTest.class)
|
||||
|
||||
@BeforeClass
|
||||
static void setUpOnce() throws Exception {
|
||||
logger.metaClass.methodMissing = { String name, args ->
|
||||
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
|
||||
}
|
||||
}
|
||||
|
||||
private static StandardNiFiProperties mockNiFiProperties() {
|
||||
[getClusterNodeConnectionTimeout: { -> "10 ms" },
|
||||
getClusterNodeReadTimeout : { -> "10 ms" },
|
||||
getProperty : { String prop ->
|
||||
logger.mock("Requested getProperty(${prop}) -> \"\"")
|
||||
""
|
||||
}] as StandardNiFiProperties
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldReplaceNonZeroContentLengthHeader() {
|
||||
// Arrange
|
||||
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
|
||||
String method = "DELETE"
|
||||
logger.info("Original headers: ${headers}")
|
||||
|
||||
NiFiProperties mockProperties = mockNiFiProperties()
|
||||
|
||||
OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)
|
||||
|
||||
// Act
|
||||
client.checkContentLengthHeader(method, headers)
|
||||
logger.info("Checked headers: ${headers}")
|
||||
|
||||
// Assert
|
||||
assert headers.size() == 2
|
||||
assert headers."Content-Length" == "0"
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldReplaceNonZeroContentLengthHeaderOnDeleteCaseInsensitive() {
|
||||
// Arrange
|
||||
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
|
||||
String method = "delete"
|
||||
logger.info("Original headers: ${headers}")
|
||||
|
||||
NiFiProperties mockProperties = mockNiFiProperties()
|
||||
|
||||
OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)
|
||||
|
||||
// Act
|
||||
client.checkContentLengthHeader(method, headers)
|
||||
logger.info("Checked headers: ${headers}")
|
||||
|
||||
// Assert
|
||||
assert headers.size() == 2
|
||||
assert headers."Content-Length" == "0"
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldNotReplaceContentLengthHeaderWhenZeroOrNull() {
|
||||
// Arrange
|
||||
String method = "DELETE"
|
||||
def zeroOrNullContentLengths = [null, "0"]
|
||||
|
||||
NiFiProperties mockProperties = mockNiFiProperties()
|
||||
|
||||
OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)
|
||||
|
||||
// Act
|
||||
zeroOrNullContentLengths.each { String contentLength ->
|
||||
def headers = ["Content-Length": contentLength, "Other-Header": "arbitrary value"]
|
||||
logger.info("Original headers: ${headers}")
|
||||
|
||||
logger.info("Trying method ${method}")
|
||||
client.checkContentLengthHeader(method, headers)
|
||||
logger.info("Checked headers: ${headers}")
|
||||
|
||||
// Assert
|
||||
assert headers.size() == 2
|
||||
assert headers."Content-Length" == contentLength
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testShouldNotReplaceNonZeroContentLengthHeaderOnOtherMethod() {
|
||||
// Arrange
|
||||
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
|
||||
logger.info("Original headers: ${headers}")
|
||||
|
||||
NiFiProperties mockProperties = mockNiFiProperties()
|
||||
|
||||
OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)
|
||||
|
||||
def nonDeleteMethods = ["POST", "PUT", "GET", "HEAD"]
|
||||
|
||||
// Act
|
||||
nonDeleteMethods.each { String method ->
|
||||
logger.info("Trying method ${method}")
|
||||
client.checkContentLengthHeader(method, headers)
|
||||
logger.info("Checked headers: ${headers}")
|
||||
|
||||
// Assert
|
||||
assert headers.size() == 2
|
||||
assert headers."Content-Length" == "123"
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue