NIFI-13799 Improved Replicated Cluster Response Handling (#9312)

- Return the remote Response Stream in the Replicated Response for unknown or large content length values
- Buffered smaller responses
- Addressed code warnings in ThreadPoolRequestReplicator
This commit is contained in:
David Handermann 2024-09-25 14:03:49 -05:00 committed by GitHub
parent 1fb8498c87
commit 2e7a39d200
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 135 additions and 68 deletions

View File

@ -106,7 +106,6 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
private final EventReporter eventReporter;
private final RequestCompletionCallback callback;
private final ClusterCoordinator clusterCoordinator;
private final NiFiProperties nifiProperties;
private final ThreadPoolExecutor executorService;
private final ScheduledExecutorService maintenanceExecutor;
@ -145,7 +144,6 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
this.responseMapper = new StandardHttpResponseMapper(nifiProperties);
this.eventReporter = eventReporter;
this.callback = callback;
this.nifiProperties = nifiProperties;
this.httpClient = client;
final AtomicInteger threadId = new AtomicInteger(0);
@ -468,7 +466,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
final Function<NodeIdentifier, NodeHttpRequest> requestFactory =
nodeId -> new NodeHttpRequest(request, nodeId, createURI(uri, nodeId), nodeCompletionCallback, finalResponse);
submitAsyncRequest(nodeIds, uri.getScheme(), uri.getPath(), requestFactory, updatedHeaders);
submitAsyncRequest(nodeIds, requestFactory);
return response;
} catch (final Throwable t) {
@ -541,17 +539,14 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
try {
final Map<String, String> cancelLockHeaders = new HashMap<>(headers);
cancelLockHeaders.put(RequestReplicationHeader.CANCEL_TRANSACTION.getHeader(), "true");
final Thread cancelLockThread = new Thread(new Runnable() {
@Override
public void run() {
logger.debug("Found {} dissenting nodes for {} {}; canceling claim request", dissentingCount, method, uri.getPath());
final Thread cancelLockThread = new Thread(() -> {
logger.debug("Found {} dissenting nodes for {} {}; canceling claim request", dissentingCount, method, uri.getPath());
final PreparedRequest request = httpClient.prepareRequest(method, cancelLockHeaders, entity);
final Function<NodeIdentifier, NodeHttpRequest> requestFactory =
nodeId -> new NodeHttpRequest(request, nodeId, createURI(uri, nodeId), null, clusterResponse);
final PreparedRequest request = httpClient.prepareRequest(method, cancelLockHeaders, entity);
final Function<NodeIdentifier, NodeHttpRequest> requestFactory =
nodeId -> new NodeHttpRequest(request, nodeId, createURI(uri, nodeId), null, clusterResponse);
submitAsyncRequest(nodeIds, uri.getScheme(), uri.getPath(), requestFactory, cancelLockHeaders);
}
submitAsyncRequest(nodeIds, requestFactory);
});
cancelLockThread.setName("Cancel Flow Locks");
cancelLockThread.start();
@ -627,30 +622,23 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
nodeId -> new NodeHttpRequest(request, nodeId, createURI(uri, nodeId), completionCallback, clusterResponse);
// replicate the 'verification request' to all nodes
submitAsyncRequest(nodeIds, uri.getScheme(), uri.getPath(), requestFactory, validationHeaders);
submitAsyncRequest(nodeIds, requestFactory);
}
@Override
public AsyncClusterResponse getClusterResponse(final String identifier) {
final AsyncClusterResponse response = responseMap.get(identifier);
if (response == null) {
return null;
}
return response;
return responseMap.get(identifier);
}
// Visible for testing - overriding this method makes it easy to verify behavior without actually making any web requests
protected NodeResponse replicateRequest(final PreparedRequest request, final NodeIdentifier nodeId, final URI uri, final String requestId,
final StandardAsyncClusterResponse clusterResponse) throws IOException {
final Response response;
final long startNanos = System.nanoTime();
logger.debug("Replicating request to {} {}, request ID = {}, headers = {}", request.getMethod(), uri, requestId, request.getHeaders());
// invoke the request
response = httpClient.replicate(request, uri);
final Response response = httpClient.replicate(request, uri);
final long nanos = System.nanoTime() - startNanos;
clusterResponse.addTiming("Perform HTTP Request", nodeId.toString(), nanos);
@ -669,14 +657,10 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
}
private boolean isMutableRequest(final String method) {
switch (method.toUpperCase()) {
case HttpMethod.GET:
case HttpMethod.HEAD:
case HttpMethod.OPTIONS:
return false;
default:
return true;
}
return switch (method.toUpperCase()) {
case HttpMethod.GET, HttpMethod.HEAD, HttpMethod.OPTIONS -> false;
default -> true;
};
}
private boolean isDeleteComponent(final String method, final String uriPath) {
@ -689,7 +673,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
// This is because we do need to allow deletion of asynchronous requests, such as updating parameters, querying provenance, etc.
// which create a request, poll until the request completes, and then deletes it. Additionally, we want to allow terminating
// Processors, which is done by issuing a request to DELETE /processors/<id>/threads
final boolean componentUri = ConnectionEndpointMerger.CONNECTION_URI_PATTERN.matcher(uriPath).matches()
return ConnectionEndpointMerger.CONNECTION_URI_PATTERN.matcher(uriPath).matches()
|| ProcessorEndpointMerger.PROCESSOR_URI_PATTERN.matcher(uriPath).matches()
|| FunnelEndpointMerger.FUNNEL_URI_PATTERN.matcher(uriPath).matches()
|| PortEndpointMerger.INPUT_PORT_URI_PATTERN.matcher(uriPath).matches()
@ -704,8 +688,6 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
|| ParameterProviderEndpointMerger.PARAMETER_PROVIDER_URI_PATTERN.matcher(uriPath).matches()
|| FlowRegistryClientEndpointMerger.CONTROLLER_REGISTRY_URI_PATTERN.matcher(uriPath).matches()
|| SNIPPET_URI_PATTERN.matcher(uriPath).matches();
return componentUri;
}
/**
@ -754,18 +736,20 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
*/
private void onCompletedResponse(final String requestId) {
final AsyncClusterResponse response = responseMap.get(requestId);
if (response == null) {
logger.info("Replicated Request [{}] not found", requestId);
return;
}
if (response != null && callback != null) {
if (callback != null) {
try {
callback.afterRequest(response.getURIPath(), response.getMethod(), response.getCompletedNodeResponses());
} catch (final Exception e) {
logger.warn("Completed request {} {} but failed to properly handle the Request Completion Callback due to {}",
response.getMethod(), response.getURIPath(), e.toString());
logger.warn("", e);
logger.warn("Completed request {} {} but failed to properly handle the Request Completion Callback", response.getMethod(), response.getURIPath(), e);
}
}
if (response != null && logger.isDebugEnabled()) {
if (logger.isDebugEnabled()) {
logTimingInfo(response);
}
@ -811,8 +795,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
}
private void submitAsyncRequest(final Set<NodeIdentifier> nodeIds, final String scheme, final String path,
final Function<NodeIdentifier, NodeHttpRequest> callableFactory, final Map<String, String> headers) {
private void submitAsyncRequest(final Set<NodeIdentifier> nodeIds, final Function<NodeIdentifier, NodeHttpRequest> callableFactory) {
if (nodeIds.isEmpty()) {
return; // return quickly for trivial case
@ -887,7 +870,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
}
}
private static interface NodeRequestCompletionCallback {
private interface NodeRequestCompletionCallback {
void onCompletion(NodeResponse nodeResponse);
}
@ -895,10 +878,10 @@ public class ThreadPoolRequestReplicator implements RequestReplicator, Closeable
final Set<String> expiredRequestIds = ThreadPoolRequestReplicator.this.responseMap.entrySet().stream()
.filter(entry -> entry.getValue().isOlderThan(30, TimeUnit.SECONDS)) // older than 30 seconds
.filter(entry -> entry.getValue().isComplete()) // is complete
.map(entry -> entry.getKey()) // get the request id
.map(Map.Entry::getKey) // get the request id
.collect(Collectors.toSet());
expiredRequestIds.forEach(id -> onResponseConsumed(id));
expiredRequestIds.forEach(this::onResponseConsumed);
return responseMap.size();
}

View File

@ -26,7 +26,7 @@ import jakarta.ws.rs.core.Response;
import org.apache.nifi.cluster.coordination.http.replication.HttpReplicationClient;
import org.apache.nifi.cluster.coordination.http.replication.PreparedRequest;
import org.apache.nifi.cluster.coordination.http.replication.io.EntitySerializer;
import org.apache.nifi.cluster.coordination.http.replication.io.JacksonResponse;
import org.apache.nifi.cluster.coordination.http.replication.io.ReplicatedResponse;
import org.apache.nifi.cluster.coordination.http.replication.io.JsonEntitySerializer;
import org.apache.nifi.cluster.coordination.http.replication.io.XmlEntitySerializer;
import org.apache.nifi.web.client.api.HttpEntityHeaders;
@ -62,6 +62,8 @@ public class StandardHttpReplicationClient implements HttpReplicationClient {
private static final Set<String> DISALLOWED_HEADERS = Set.of("connection", "content-length", "expect", "host", "upgrade");
private static final int CONTENT_LENGTH_NOT_FOUND = -1;
private static final char PSEUDO_HEADER_PREFIX = ':';
private static final String GZIP_ENCODING = "gzip";
@ -199,17 +201,25 @@ public class StandardHttpReplicationClient implements HttpReplicationClient {
private Response replicate(final HttpRequestBodySpec httpRequestBodySpec, final String method, final URI location) throws IOException {
final long started = System.currentTimeMillis();
try (HttpResponseEntity responseEntity = httpRequestBodySpec.retrieve()) {
final int statusCode = responseEntity.statusCode();
final HttpEntityHeaders headers = responseEntity.headers();
final MultivaluedMap<String, String> responseHeaders = getResponseHeaders(headers);
final byte[] responseBody = getResponseBody(responseEntity.body(), headers);
final HttpResponseEntity responseEntity = httpRequestBodySpec.retrieve();
final int statusCode = responseEntity.statusCode();
final HttpEntityHeaders headers = responseEntity.headers();
final MultivaluedMap<String, String> responseHeaders = getResponseHeaders(headers);
final int contentLength = getContentLength(headers);
final long elapsed = System.currentTimeMillis() - started;
logger.debug("Replicated {} {} HTTP {} in {} ms", method, location, statusCode, elapsed);
final InputStream responseBody = getResponseBody(responseEntity.body(), headers);
final Runnable closeCallback = () -> {
try {
responseEntity.close();
} catch (final IOException e) {
logger.warn("Close failed for Replicated {} {} HTTP {}", method, location, statusCode, e);
}
};
return new JacksonResponse(objectMapper, responseBody, responseHeaders, location, statusCode, null);
}
final long elapsed = System.currentTimeMillis() - started;
logger.debug("Replicated {} {} HTTP {} in {} ms", method, location, statusCode, elapsed);
return new ReplicatedResponse(objectMapper, responseBody, responseHeaders, location, statusCode, contentLength, closeCallback);
}
private URI getRequestUri(final StandardPreparedRequest preparedRequest, final URI location) {
@ -288,14 +298,32 @@ public class StandardHttpReplicationClient implements HttpReplicationClient {
return headers;
}
private byte[] getResponseBody(final InputStream inputStream, final HttpEntityHeaders responseHeaders) throws IOException {
private InputStream getResponseBody(final InputStream inputStream, final HttpEntityHeaders responseHeaders) throws IOException {
final boolean gzipEncoded = isGzipEncoded(responseHeaders);
return gzipEncoded ? new GZIPInputStream(inputStream) : inputStream;
}
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
try (InputStream responseBodyStream = gzipEncoded ? new GZIPInputStream(inputStream) : inputStream) {
responseBodyStream.transferTo(outputStream);
private int getContentLength(final HttpEntityHeaders headers) {
final Optional<String> contentLengthFound = headers.getHeaderNames()
.stream()
.filter(PreparedRequestHeader.CONTENT_LENGTH.getHeader()::equalsIgnoreCase)
.findFirst()
.flatMap(headers::getFirstHeader);
int contentLength;
if (contentLengthFound.isPresent()) {
final String contentLengthHeader = contentLengthFound.get();
try {
contentLength = Integer.parseInt(contentLengthHeader);
} catch (final NumberFormatException e) {
logger.warn("Replicated Header Content-Length [{}] parsing failed", contentLengthHeader, e);
contentLength = CONTENT_LENGTH_NOT_FOUND;
}
} else {
contentLength = CONTENT_LENGTH_NOT_FOUND;
}
return outputStream.toByteArray();
return contentLength;
}
private byte[] getRequestBody(final Object requestEntity, final Map<String, String> headers) {

View File

@ -18,7 +18,9 @@
package org.apache.nifi.cluster.coordination.http.replication.io;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.annotation.Annotation;
import java.net.URI;
import java.nio.charset.StandardCharsets;
@ -43,24 +45,49 @@ import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.ObjectMapper;
public class JacksonResponse extends Response {
/**
* Replicated extension of standard Response with HTTP properties returned
*/
public class ReplicatedResponse extends Response {
private static final int MAXIMUM_BUFFER_SIZE = 1048576;
private static final int CONTENT_LENGTH_UNKNOWN = -1;
private final ObjectMapper codec;
private final byte[] responseBody;
private final InputStream responseBody;
private final MultivaluedMap<String, String> responseHeaders;
private final URI location;
private final int statusCode;
private final Runnable closeCallback;
private final int contentLength;
private final JsonFactory jsonFactory = new JsonFactory();
public JacksonResponse(final ObjectMapper codec, final byte[] responseBody, final MultivaluedMap<String, String> responseHeaders, final URI location, final int statusCode,
final Runnable closeCallback) {
private final byte[] bufferedResponseBody;
public ReplicatedResponse(
final ObjectMapper codec,
final InputStream responseBody,
final MultivaluedMap<String, String> responseHeaders,
final URI location,
final int statusCode,
final int contentLength,
final Runnable closeCallback
) {
this.codec = codec;
this.responseBody = responseBody;
this.responseHeaders = responseHeaders;
this.location = location;
this.statusCode = statusCode;
this.closeCallback = closeCallback;
if (contentLength == CONTENT_LENGTH_UNKNOWN || contentLength > MAXIMUM_BUFFER_SIZE) {
// Avoid buffering unknown Content-Length or greater than maximum buffer size specified
bufferedResponseBody = null;
this.contentLength = CONTENT_LENGTH_UNKNOWN;
} else {
bufferedResponseBody = readResponseBody(responseBody, location, statusCode);
this.contentLength = bufferedResponseBody.length;
}
}
@Override
@ -75,8 +102,10 @@ public class JacksonResponse extends Response {
@Override
public Object getEntity() {
final InputStream responseBodyStream = getResponseBodyStream();
try {
final JsonParser parser = jsonFactory.createParser(responseBody);
final JsonParser parser = jsonFactory.createParser(responseBodyStream);
parser.setCodec(codec);
return parser.readValueAs(Object.class);
} catch (final Exception e) {
@ -87,16 +116,23 @@ public class JacksonResponse extends Response {
@Override
@SuppressWarnings("unchecked")
public <T> T readEntity(Class<T> entityType) {
final InputStream responseBodyStream = getResponseBodyStream();
if (InputStream.class.equals(entityType)) {
return (T) new ByteArrayInputStream(responseBody);
return (T) responseBodyStream;
}
if (String.class.equals(entityType)) {
return (T) new String(responseBody, StandardCharsets.UTF_8);
try {
final byte[] responseBytes = responseBodyStream.readAllBytes();
return (T) new String(responseBytes, StandardCharsets.UTF_8);
} catch (final IOException e) {
throw new UncheckedIOException("Read Replicated Response Body to String failed for %s".formatted(location), e);
}
}
try {
final JsonParser parser = jsonFactory.createParser(responseBody);
final JsonParser parser = jsonFactory.createParser(responseBodyStream);
parser.setCodec(codec);
return parser.readValueAs(entityType);
} catch (final Exception e) {
@ -121,7 +157,7 @@ public class JacksonResponse extends Response {
@Override
public boolean hasEntity() {
return responseBody != null && responseBody.length > 0;
return true;
}
@Override
@ -148,7 +184,7 @@ public class JacksonResponse extends Response {
@Override
public int getLength() {
return responseBody == null ? 0 : responseBody.length;
return contentLength;
}
@Override
@ -239,4 +275,24 @@ public class JacksonResponse extends Response {
return responseHeaders.getFirst(name.toLowerCase());
}
private InputStream getResponseBodyStream() {
final InputStream responseBodyStream;
if (bufferedResponseBody == null) {
responseBodyStream = responseBody;
} else {
responseBodyStream = new ByteArrayInputStream(bufferedResponseBody);
}
return responseBodyStream;
}
private static byte[] readResponseBody(final InputStream inputStream, final URI location, final int statusCode) {
try {
return inputStream.readAllBytes();
} catch (final IOException e) {
throw new UncheckedIOException("Buffering Replicated Response Body failed %s HTTP %d".formatted(location, statusCode), e);
}
}
}