mirror of
https://github.com/apache/nifi.git
synced 2025-02-07 18:48:51 +00:00
NIFI-5628 Added content length check to OkHttpReplicationClient.
Added unit tests. This closes #3035
This commit is contained in:
parent
0dd382370b
commit
748cf74562
@ -21,6 +21,35 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
|||||||
import com.fasterxml.jackson.annotation.JsonInclude.Value;
|
import com.fasterxml.jackson.annotation.JsonInclude.Value;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.fasterxml.jackson.module.jaxb.JaxbAnnotationIntrospector;
|
import com.fasterxml.jackson.module.jaxb.JaxbAnnotationIntrospector;
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.security.KeyStore;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
import java.util.zip.GZIPInputStream;
|
||||||
|
import javax.net.ssl.KeyManager;
|
||||||
|
import javax.net.ssl.KeyManagerFactory;
|
||||||
|
import javax.net.ssl.SSLContext;
|
||||||
|
import javax.net.ssl.SSLSocketFactory;
|
||||||
|
import javax.net.ssl.TrustManager;
|
||||||
|
import javax.net.ssl.TrustManagerFactory;
|
||||||
|
import javax.net.ssl.X509TrustManager;
|
||||||
|
import javax.ws.rs.HttpMethod;
|
||||||
|
import javax.ws.rs.core.MultivaluedHashMap;
|
||||||
|
import javax.ws.rs.core.MultivaluedMap;
|
||||||
|
import javax.ws.rs.core.Response;
|
||||||
import okhttp3.Call;
|
import okhttp3.Call;
|
||||||
import okhttp3.ConnectionPool;
|
import okhttp3.ConnectionPool;
|
||||||
import okhttp3.Headers;
|
import okhttp3.Headers;
|
||||||
@ -42,36 +71,6 @@ import org.slf4j.Logger;
|
|||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.util.StreamUtils;
|
import org.springframework.util.StreamUtils;
|
||||||
|
|
||||||
import javax.net.ssl.KeyManager;
|
|
||||||
import javax.net.ssl.KeyManagerFactory;
|
|
||||||
import javax.net.ssl.SSLContext;
|
|
||||||
import javax.net.ssl.SSLSocketFactory;
|
|
||||||
import javax.net.ssl.TrustManager;
|
|
||||||
import javax.net.ssl.TrustManagerFactory;
|
|
||||||
import javax.net.ssl.X509TrustManager;
|
|
||||||
import javax.ws.rs.HttpMethod;
|
|
||||||
import javax.ws.rs.core.MultivaluedHashMap;
|
|
||||||
import javax.ws.rs.core.MultivaluedMap;
|
|
||||||
import javax.ws.rs.core.Response;
|
|
||||||
import java.io.ByteArrayInputStream;
|
|
||||||
import java.io.ByteArrayOutputStream;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.io.OutputStream;
|
|
||||||
import java.net.URI;
|
|
||||||
import java.security.KeyStore;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Map.Entry;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import java.util.stream.Stream;
|
|
||||||
import java.util.zip.GZIPInputStream;
|
|
||||||
|
|
||||||
public class OkHttpReplicationClient implements HttpReplicationClient {
|
public class OkHttpReplicationClient implements HttpReplicationClient {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClient.class);
|
private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClient.class);
|
||||||
private static final Set<String> gzipEncodings = Stream.of("gzip", "x-gzip").collect(Collectors.toSet());
|
private static final Set<String> gzipEncodings = Stream.of("gzip", "x-gzip").collect(Collectors.toSet());
|
||||||
@ -95,12 +94,35 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||||||
@Override
|
@Override
|
||||||
public PreparedRequest prepareRequest(final String method, final Map<String, String> headers, final Object entity) {
|
public PreparedRequest prepareRequest(final String method, final Map<String, String> headers, final Object entity) {
|
||||||
final boolean gzip = isUseGzip(headers);
|
final boolean gzip = isUseGzip(headers);
|
||||||
|
checkContentLengthHeader(method, headers);
|
||||||
final RequestBody requestBody = createRequestBody(headers, entity, gzip);
|
final RequestBody requestBody = createRequestBody(headers, entity, gzip);
|
||||||
|
|
||||||
final Map<String, String> updatedHeaders = gzip ? updateHeadersForGzip(headers) : headers;
|
final Map<String, String> updatedHeaders = gzip ? updateHeadersForGzip(headers) : headers;
|
||||||
return new OkHttpPreparedRequest(method, updatedHeaders, entity, requestBody);
|
return new OkHttpPreparedRequest(method, updatedHeaders, entity, requestBody);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks the content length header on DELETE requests to ensure it is set to '0', avoiding request timeouts on replicated requests.
|
||||||
|
* @param method the HTTP method of the request
|
||||||
|
* @param headers the header keys and values
|
||||||
|
*/
|
||||||
|
private void checkContentLengthHeader(String method, Map<String, String> headers) {
|
||||||
|
// Only applies to DELETE requests
|
||||||
|
if (HttpMethod.DELETE.equalsIgnoreCase(method)) {
|
||||||
|
// Find the Content-Length header if present
|
||||||
|
final String CONTENT_LENGTH_HEADER_KEY = "Content-Length";
|
||||||
|
Map.Entry<String, String> contentLengthEntry = headers.entrySet().stream().filter(entry -> entry.getKey().equalsIgnoreCase(CONTENT_LENGTH_HEADER_KEY)).findFirst().orElse(null);
|
||||||
|
// If no CL header, do nothing
|
||||||
|
if (contentLengthEntry != null) {
|
||||||
|
// If the provided CL value is non-zero, override it
|
||||||
|
if (contentLengthEntry.getValue() != null && !contentLengthEntry.getValue().equalsIgnoreCase("0")) {
|
||||||
|
logger.warn("This is a DELETE request; the provided Content-Length was {}; setting Content-Length to 0", contentLengthEntry.getValue());
|
||||||
|
headers.put(CONTENT_LENGTH_HEADER_KEY, "0");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Response replicate(final PreparedRequest request, final String uri) throws IOException {
|
public Response replicate(final PreparedRequest request, final String uri) throws IOException {
|
||||||
if (!(Objects.requireNonNull(request) instanceof OkHttpPreparedRequest)) {
|
if (!(Objects.requireNonNull(request) instanceof OkHttpPreparedRequest)) {
|
||||||
@ -140,7 +162,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||||||
final String contentEncoding = callResponse.header("Content-Encoding");
|
final String contentEncoding = callResponse.header("Content-Encoding");
|
||||||
if (gzipEncodings.contains(contentEncoding)) {
|
if (gzipEncodings.contains(contentEncoding)) {
|
||||||
try (final InputStream gzipIn = new GZIPInputStream(new ByteArrayInputStream(rawBytes));
|
try (final InputStream gzipIn = new GZIPInputStream(new ByteArrayInputStream(rawBytes));
|
||||||
final ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
|
final ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
|
||||||
|
|
||||||
StreamUtils.copy(gzipIn, baos);
|
StreamUtils.copy(gzipIn, baos);
|
||||||
return baos.toByteArray();
|
return baos.toByteArray();
|
||||||
@ -183,7 +205,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private HttpUrl buildUrl(final OkHttpPreparedRequest request, final String uri) {
|
private HttpUrl buildUrl(final OkHttpPreparedRequest request, final String uri) {
|
||||||
HttpUrl.Builder urlBuilder = HttpUrl.parse(uri.toString()).newBuilder();
|
HttpUrl.Builder urlBuilder = HttpUrl.parse(uri).newBuilder();
|
||||||
switch (request.getMethod().toUpperCase()) {
|
switch (request.getMethod().toUpperCase()) {
|
||||||
case HttpMethod.DELETE:
|
case HttpMethod.DELETE:
|
||||||
case HttpMethod.HEAD:
|
case HttpMethod.HEAD:
|
||||||
@ -226,7 +248,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||||||
|
|
||||||
private byte[] serializeEntity(final Object entity, final String contentType, final boolean gzip) {
|
private byte[] serializeEntity(final Object entity, final String contentType, final boolean gzip) {
|
||||||
try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
final OutputStream out = gzip ? new GZIPOutputStream(baos, 1) : baos) {
|
final OutputStream out = gzip ? new GZIPOutputStream(baos, 1) : baos) {
|
||||||
|
|
||||||
getSerializer(contentType).serialize(entity, out);
|
getSerializer(contentType).serialize(entity, out);
|
||||||
out.close();
|
out.close();
|
||||||
@ -269,10 +291,10 @@ public class OkHttpReplicationClient implements HttpReplicationClient {
|
|||||||
} else {
|
} else {
|
||||||
final String[] acceptEncodingTokens = rawAcceptEncoding.split(",");
|
final String[] acceptEncodingTokens = rawAcceptEncoding.split(",");
|
||||||
return Stream.of(acceptEncodingTokens)
|
return Stream.of(acceptEncodingTokens)
|
||||||
.map(String::trim)
|
.map(String::trim)
|
||||||
.filter(StringUtils::isNotEmpty)
|
.filter(StringUtils::isNotEmpty)
|
||||||
.map(String::toLowerCase)
|
.map(String::toLowerCase)
|
||||||
.anyMatch(gzipEncodings::contains);
|
.anyMatch(gzipEncodings::contains);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,138 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
package org.apache.nifi.cluster.coordination.http.replication.okhttp
|
||||||
|
|
||||||
|
import org.apache.nifi.properties.StandardNiFiProperties
|
||||||
|
import org.apache.nifi.util.NiFiProperties
|
||||||
|
import org.junit.BeforeClass
|
||||||
|
import org.junit.Test
|
||||||
|
import org.junit.runner.RunWith
|
||||||
|
import org.junit.runners.JUnit4
|
||||||
|
import org.slf4j.Logger
|
||||||
|
import org.slf4j.LoggerFactory
|
||||||
|
|
||||||
|
@RunWith(JUnit4.class)
|
||||||
|
class OkHttpReplicationClientTest extends GroovyTestCase {
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClientTest.class)
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
static void setUpOnce() throws Exception {
|
||||||
|
logger.metaClass.methodMissing = { String name, args ->
|
||||||
|
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static StandardNiFiProperties mockNiFiProperties() {
|
||||||
|
[getClusterNodeConnectionTimeout: { -> "10 ms" },
|
||||||
|
getClusterNodeReadTimeout : { -> "10 ms" },
|
||||||
|
getProperty : { String prop ->
|
||||||
|
logger.mock("Requested getProperty(${prop}) -> \"\"")
|
||||||
|
""
|
||||||
|
}] as StandardNiFiProperties
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testShouldReplaceNonZeroContentLengthHeader() {
|
||||||
|
// Arrange
|
||||||
|
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
|
||||||
|
String method = "DELETE"
|
||||||
|
logger.info("Original headers: ${headers}")
|
||||||
|
|
||||||
|
NiFiProperties mockProperties = mockNiFiProperties()
|
||||||
|
|
||||||
|
OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)
|
||||||
|
|
||||||
|
// Act
|
||||||
|
client.checkContentLengthHeader(method, headers)
|
||||||
|
logger.info("Checked headers: ${headers}")
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert headers.size() == 2
|
||||||
|
assert headers."Content-Length" == "0"
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testShouldReplaceNonZeroContentLengthHeaderOnDeleteCaseInsensitive() {
|
||||||
|
// Arrange
|
||||||
|
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
|
||||||
|
String method = "delete"
|
||||||
|
logger.info("Original headers: ${headers}")
|
||||||
|
|
||||||
|
NiFiProperties mockProperties = mockNiFiProperties()
|
||||||
|
|
||||||
|
OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)
|
||||||
|
|
||||||
|
// Act
|
||||||
|
client.checkContentLengthHeader(method, headers)
|
||||||
|
logger.info("Checked headers: ${headers}")
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert headers.size() == 2
|
||||||
|
assert headers."Content-Length" == "0"
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testShouldNotReplaceContentLengthHeaderWhenZeroOrNull() {
|
||||||
|
// Arrange
|
||||||
|
String method = "DELETE"
|
||||||
|
def zeroOrNullContentLengths = [null, "0"]
|
||||||
|
|
||||||
|
NiFiProperties mockProperties = mockNiFiProperties()
|
||||||
|
|
||||||
|
OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)
|
||||||
|
|
||||||
|
// Act
|
||||||
|
zeroOrNullContentLengths.each { String contentLength ->
|
||||||
|
def headers = ["Content-Length": contentLength, "Other-Header": "arbitrary value"]
|
||||||
|
logger.info("Original headers: ${headers}")
|
||||||
|
|
||||||
|
logger.info("Trying method ${method}")
|
||||||
|
client.checkContentLengthHeader(method, headers)
|
||||||
|
logger.info("Checked headers: ${headers}")
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert headers.size() == 2
|
||||||
|
assert headers."Content-Length" == contentLength
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testShouldNotReplaceNonZeroContentLengthHeaderOnOtherMethod() {
|
||||||
|
// Arrange
|
||||||
|
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
|
||||||
|
logger.info("Original headers: ${headers}")
|
||||||
|
|
||||||
|
NiFiProperties mockProperties = mockNiFiProperties()
|
||||||
|
|
||||||
|
OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)
|
||||||
|
|
||||||
|
def nonDeleteMethods = ["POST", "PUT", "GET", "HEAD"]
|
||||||
|
|
||||||
|
// Act
|
||||||
|
nonDeleteMethods.each { String method ->
|
||||||
|
logger.info("Trying method ${method}")
|
||||||
|
client.checkContentLengthHeader(method, headers)
|
||||||
|
logger.info("Checked headers: ${headers}")
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert headers.size() == 2
|
||||||
|
assert headers."Content-Length" == "123"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user