From 748cf745628dab20b7e71f12b5dcfe6ed0bbf134 Mon Sep 17 00:00:00 2001 From: Andy LoPresto Date: Wed, 26 Sep 2018 18:18:22 -0700 Subject: [PATCH] NIFI-5628 Added content length check to OkHttpReplicationClient. Added unit tests. This closes #3035 --- .../okhttp/OkHttpReplicationClient.java | 96 +++++++----- .../okhttp/OkHttpReplicationClientTest.groovy | 138 ++++++++++++++++++ 2 files changed, 197 insertions(+), 37 deletions(-) create mode 100644 nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-cluster/src/test/groovy/org/apache/nifi/cluster/coordination/http/replication/okhttp/OkHttpReplicationClientTest.groovy diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-cluster/src/main/java/org/apache/nifi/cluster/coordination/http/replication/okhttp/OkHttpReplicationClient.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-cluster/src/main/java/org/apache/nifi/cluster/coordination/http/replication/okhttp/OkHttpReplicationClient.java index b0f0a39429..ec8a2b091f 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-cluster/src/main/java/org/apache/nifi/cluster/coordination/http/replication/okhttp/OkHttpReplicationClient.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-cluster/src/main/java/org/apache/nifi/cluster/coordination/http/replication/okhttp/OkHttpReplicationClient.java @@ -21,6 +21,35 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonInclude.Value; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.module.jaxb.JaxbAnnotationIntrospector; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.security.KeyStore; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.zip.GZIPInputStream; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import javax.ws.rs.HttpMethod; +import javax.ws.rs.core.MultivaluedHashMap; +import javax.ws.rs.core.MultivaluedMap; +import javax.ws.rs.core.Response; import okhttp3.Call; import okhttp3.ConnectionPool; import okhttp3.Headers; @@ -42,36 +71,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.util.StreamUtils; -import javax.net.ssl.KeyManager; -import javax.net.ssl.KeyManagerFactory; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLSocketFactory; -import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; -import javax.net.ssl.X509TrustManager; -import javax.ws.rs.HttpMethod; -import javax.ws.rs.core.MultivaluedHashMap; -import javax.ws.rs.core.MultivaluedMap; -import javax.ws.rs.core.Response; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.net.URI; -import java.security.KeyStore; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import java.util.zip.GZIPInputStream; - public class OkHttpReplicationClient implements HttpReplicationClient { private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClient.class); private static final Set gzipEncodings = Stream.of("gzip", "x-gzip").collect(Collectors.toSet()); @@ -95,12 +94,35 @@ public class OkHttpReplicationClient implements HttpReplicationClient { @Override public PreparedRequest prepareRequest(final String method, final Map headers, final Object entity) { final boolean gzip = isUseGzip(headers); + checkContentLengthHeader(method, headers); final RequestBody requestBody = createRequestBody(headers, entity, gzip); final Map updatedHeaders = gzip ? updateHeadersForGzip(headers) : headers; return new OkHttpPreparedRequest(method, updatedHeaders, entity, requestBody); } + /** + * Checks the content length header on DELETE requests to ensure it is set to '0', avoiding request timeouts on replicated requests. + * @param method the HTTP method of the request + * @param headers the header keys and values + */ + private void checkContentLengthHeader(String method, Map headers) { + // Only applies to DELETE requests + if (HttpMethod.DELETE.equalsIgnoreCase(method)) { + // Find the Content-Length header if present + final String CONTENT_LENGTH_HEADER_KEY = "Content-Length"; + Map.Entry contentLengthEntry = headers.entrySet().stream().filter(entry -> entry.getKey().equalsIgnoreCase(CONTENT_LENGTH_HEADER_KEY)).findFirst().orElse(null); + // If no CL header, do nothing + if (contentLengthEntry != null) { + // If the provided CL value is non-zero, override it + if (contentLengthEntry.getValue() != null && !contentLengthEntry.getValue().equalsIgnoreCase("0")) { + logger.warn("This is a DELETE request; the provided Content-Length was {}; setting Content-Length to 0", contentLengthEntry.getValue()); + headers.put(CONTENT_LENGTH_HEADER_KEY, "0"); + } + } + } + } + @Override public Response replicate(final PreparedRequest request, final String uri) throws IOException { if (!(Objects.requireNonNull(request) instanceof OkHttpPreparedRequest)) { @@ -140,7 +162,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient { final String contentEncoding = callResponse.header("Content-Encoding"); if (gzipEncodings.contains(contentEncoding)) { try (final InputStream gzipIn = new GZIPInputStream(new ByteArrayInputStream(rawBytes)); - final ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream()) { StreamUtils.copy(gzipIn, baos); return baos.toByteArray(); @@ -183,7 +205,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient { @SuppressWarnings("unchecked") private HttpUrl buildUrl(final OkHttpPreparedRequest request, final String uri) { - HttpUrl.Builder urlBuilder = HttpUrl.parse(uri.toString()).newBuilder(); + HttpUrl.Builder urlBuilder = HttpUrl.parse(uri).newBuilder(); switch (request.getMethod().toUpperCase()) { case HttpMethod.DELETE: case HttpMethod.HEAD: @@ -226,7 +248,7 @@ public class OkHttpReplicationClient implements HttpReplicationClient { private byte[] serializeEntity(final Object entity, final String contentType, final boolean gzip) { try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - final OutputStream out = gzip ? new GZIPOutputStream(baos, 1) : baos) { + final OutputStream out = gzip ? new GZIPOutputStream(baos, 1) : baos) { getSerializer(contentType).serialize(entity, out); out.close(); @@ -269,10 +291,10 @@ public class OkHttpReplicationClient implements HttpReplicationClient { } else { final String[] acceptEncodingTokens = rawAcceptEncoding.split(","); return Stream.of(acceptEncodingTokens) - .map(String::trim) - .filter(StringUtils::isNotEmpty) - .map(String::toLowerCase) - .anyMatch(gzipEncodings::contains); + .map(String::trim) + .filter(StringUtils::isNotEmpty) + .map(String::toLowerCase) + .anyMatch(gzipEncodings::contains); } } diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-cluster/src/test/groovy/org/apache/nifi/cluster/coordination/http/replication/okhttp/OkHttpReplicationClientTest.groovy b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-cluster/src/test/groovy/org/apache/nifi/cluster/coordination/http/replication/okhttp/OkHttpReplicationClientTest.groovy new file mode 100644 index 0000000000..cad27f182b --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-cluster/src/test/groovy/org/apache/nifi/cluster/coordination/http/replication/okhttp/OkHttpReplicationClientTest.groovy @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.nifi.cluster.coordination.http.replication.okhttp + +import org.apache.nifi.properties.StandardNiFiProperties +import org.apache.nifi.util.NiFiProperties +import org.junit.BeforeClass +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +@RunWith(JUnit4.class) +class OkHttpReplicationClientTest extends GroovyTestCase { + private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClientTest.class) + + @BeforeClass + static void setUpOnce() throws Exception { + logger.metaClass.methodMissing = { String name, args -> + logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") + } + } + + private static StandardNiFiProperties mockNiFiProperties() { + [getClusterNodeConnectionTimeout: { -> "10 ms" }, + getClusterNodeReadTimeout : { -> "10 ms" }, + getProperty : { String prop -> + logger.mock("Requested getProperty(${prop}) -> \"\"") + "" + }] as StandardNiFiProperties + } + + @Test + void testShouldReplaceNonZeroContentLengthHeader() { + // Arrange + def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"] + String method = "DELETE" + logger.info("Original headers: ${headers}") + + NiFiProperties mockProperties = mockNiFiProperties() + + OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties) + + // Act + client.checkContentLengthHeader(method, headers) + logger.info("Checked headers: ${headers}") + + // Assert + assert headers.size() == 2 + assert headers."Content-Length" == "0" + } + + @Test + void testShouldReplaceNonZeroContentLengthHeaderOnDeleteCaseInsensitive() { + // Arrange + def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"] + String method = "delete" + logger.info("Original headers: ${headers}") + + NiFiProperties mockProperties = mockNiFiProperties() + + OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties) + + // Act + client.checkContentLengthHeader(method, headers) + logger.info("Checked headers: ${headers}") + + // Assert + assert headers.size() == 2 + assert headers."Content-Length" == "0" + } + + @Test + void testShouldNotReplaceContentLengthHeaderWhenZeroOrNull() { + // Arrange + String method = "DELETE" + def zeroOrNullContentLengths = [null, "0"] + + NiFiProperties mockProperties = mockNiFiProperties() + + OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties) + + // Act + zeroOrNullContentLengths.each { String contentLength -> + def headers = ["Content-Length": contentLength, "Other-Header": "arbitrary value"] + logger.info("Original headers: ${headers}") + + logger.info("Trying method ${method}") + client.checkContentLengthHeader(method, headers) + logger.info("Checked headers: ${headers}") + + // Assert + assert headers.size() == 2 + assert headers."Content-Length" == contentLength + } + } + + @Test + void testShouldNotReplaceNonZeroContentLengthHeaderOnOtherMethod() { + // Arrange + def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"] + logger.info("Original headers: ${headers}") + + NiFiProperties mockProperties = mockNiFiProperties() + + OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties) + + def nonDeleteMethods = ["POST", "PUT", "GET", "HEAD"] + + // Act + nonDeleteMethods.each { String method -> + logger.info("Trying method ${method}") + client.checkContentLengthHeader(method, headers) + logger.info("Checked headers: ${headers}") + + // Assert + assert headers.size() == 2 + assert headers."Content-Length" == "123" + } + } +}