From 619028c33e09dcf8012bef6e3acbca279fea4fe3 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Tue, 7 Apr 2020 17:10:31 -0600 Subject: [PATCH] Implement transport circuit breaking in aggregator (#54927) This commit moves the action name validation and circuit breaking into the InboundAggregator. This work is valuable because it lays the groundwork for incrementally circuit breaking as data is received. This PR includes the follow behavioral change: Handshakes contribute to circuit breaking, but cannot be broken. They currently do not contribute nor are they broken. --- .../netty4/Netty4MessageChannelHandler.java | 4 +- .../transport/nio/TcpReadWriteHandler.java | 7 +- .../transport/InboundAggregator.java | 158 +++++++++++++++--- .../transport/InboundDecoder.java | 23 +-- .../transport/InboundHandler.java | 120 +++++++------ .../transport/InboundMessage.java | 32 +++- .../transport/InboundPipeline.java | 42 ++--- .../elasticsearch/transport/TcpTransport.java | 19 ++- .../transport/TcpTransportChannel.java | 14 +- .../transport/InboundAggregatorTests.java | 131 +++++++++++++-- .../transport/InboundDecoderTests.java | 35 +--- .../transport/InboundHandlerTests.java | 8 +- .../transport/InboundPipelineTests.java | 112 ++++++++----- .../transport/OutboundHandlerTests.java | 16 +- .../common/breaker/TestCircuitBreaker.java | 47 ++++++ .../AbstractSimpleTransportTestCase.java | 24 +-- .../transport/nio/MockNioTransport.java | 11 +- 17 files changed, 544 insertions(+), 259 deletions(-) create mode 100644 test/framework/src/main/java/org/elasticsearch/common/breaker/TestCircuitBreaker.java diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java index 5f03d08b4c6..cd2ba31c13d 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java @@ -32,6 +32,7 @@ import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.InboundHandler; import org.elasticsearch.transport.InboundPipeline; import org.elasticsearch.transport.Transports; @@ -55,8 +56,9 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler { Netty4MessageChannelHandler(PageCacheRecycler recycler, Netty4Transport transport) { this.transport = transport; final ThreadPool threadPool = transport.getThreadPool(); + final InboundHandler inboundHandler = transport.getInboundHandler(); this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis, - transport::inboundMessage, transport::inboundDecodeException); + transport.getInflightBreaker(), inboundHandler::getRequestHandler, transport::inboundMessage); } @Override diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java index bf3473a16ce..d4bd764cb2c 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpReadWriteHandler.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport.nio; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -30,10 +31,12 @@ import org.elasticsearch.nio.BytesWriteHandler; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.Page; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.InboundHandler; import org.elasticsearch.transport.InboundPipeline; import org.elasticsearch.transport.TcpTransport; import java.io.IOException; +import java.util.function.Supplier; public class TcpReadWriteHandler extends BytesWriteHandler { @@ -43,8 +46,10 @@ public class TcpReadWriteHandler extends BytesWriteHandler { public TcpReadWriteHandler(NioTcpChannel channel, PageCacheRecycler recycler, TcpTransport transport) { this.channel = channel; final ThreadPool threadPool = transport.getThreadPool(); + final Supplier breaker = transport.getInflightBreaker(); + final InboundHandler inboundHandler = transport.getInboundHandler(); this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, threadPool::relativeTimeInMillis, - transport::inboundMessage, transport::inboundDecodeException); + breaker, inboundHandler::getRequestHandler, transport::inboundMessage); } @Override diff --git a/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java b/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java index 970747fc232..4516bfe8b16 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundAggregator.java @@ -19,6 +19,8 @@ package org.elasticsearch.transport; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -27,44 +29,69 @@ import org.elasticsearch.common.lease.Releasables; import java.io.IOException; import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; public class InboundAggregator implements Releasable { + private final Supplier circuitBreaker; + private final Predicate requestCanTripBreaker; + private ReleasableBytesReference firstContent; private ArrayList contentAggregation; private Header currentHeader; + private Exception aggregationException; + private boolean canTripBreaker = true; private boolean isClosed = false; + public InboundAggregator(Supplier circuitBreaker, + Function> registryFunction) { + this(circuitBreaker, (Predicate) actionName -> { + final RequestHandlerRegistry reg = registryFunction.apply(actionName); + if (reg == null) { + throw new ActionNotFoundTransportException(actionName); + } else { + return reg.canTripCircuitBreaker(); + } + }); + } + + // Visible for testing + InboundAggregator(Supplier circuitBreaker, Predicate requestCanTripBreaker) { + this.circuitBreaker = circuitBreaker; + this.requestCanTripBreaker = requestCanTripBreaker; + } + public void headerReceived(Header header) { ensureOpen(); assert isAggregating() == false; assert firstContent == null && contentAggregation == null; currentHeader = header; + if (currentHeader.isRequest() && currentHeader.needsToReadVariableHeader() == false) { + initializeRequestState(); + } } public void aggregate(ReleasableBytesReference content) { ensureOpen(); assert isAggregating(); - if (isFirstContent()) { - firstContent = content.retain(); - } else { - if (contentAggregation == null) { - contentAggregation = new ArrayList<>(4); - contentAggregation.add(firstContent); - firstContent = null; + if (isShortCircuited() == false) { + if (isFirstContent()) { + firstContent = content.retain(); + } else { + if (contentAggregation == null) { + contentAggregation = new ArrayList<>(4); + assert firstContent != null; + contentAggregation.add(firstContent); + firstContent = null; + } + contentAggregation.add(content.retain()); } - contentAggregation.add(content.retain()); } } - public Header cancelAggregation() { - ensureOpen(); - assert isAggregating(); - final Header header = this.currentHeader; - closeCurrentAggregation(); - return header; - } - public InboundMessage finishAggregation() throws IOException { ensureOpen(); final ReleasableBytesReference releasableContent; @@ -77,16 +104,30 @@ public class InboundAggregator implements Releasable { final CompositeBytesReference content = new CompositeBytesReference(references); releasableContent = new ReleasableBytesReference(content, () -> Releasables.close(references)); } - final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent); - resetCurrentAggregation(); + + final BreakerControl breakerControl = new BreakerControl(circuitBreaker); + final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl); boolean success = false; try { if (aggregated.getHeader().needsToReadVariableHeader()) { aggregated.getHeader().finishParsingHeader(aggregated.openOrGetStreamInput()); + if (aggregated.getHeader().isRequest()) { + initializeRequestState(); + } + } + if (isShortCircuited() == false) { + checkBreaker(aggregated.getHeader(), aggregated.getContentLength(), breakerControl); + } + if (isShortCircuited()) { + aggregated.close(); + success = true; + return new InboundMessage(aggregated.getHeader(), aggregationException); + } else { + success = true; + return aggregated; } - success = true; - return aggregated; } finally { + resetCurrentAggregation(); if (success == false) { aggregated.close(); } @@ -97,6 +138,14 @@ public class InboundAggregator implements Releasable { return currentHeader != null; } + private void shortCircuit(Exception exception) { + this.aggregationException = exception; + } + + private boolean isShortCircuited() { + return aggregationException != null; + } + private boolean isFirstContent() { return firstContent == null && contentAggregation == null; } @@ -108,18 +157,24 @@ public class InboundAggregator implements Releasable { } private void closeCurrentAggregation() { + releaseContent(); + resetCurrentAggregation(); + } + + private void releaseContent() { if (contentAggregation == null) { Releasables.close(firstContent); } else { Releasables.close(contentAggregation); } - resetCurrentAggregation(); } private void resetCurrentAggregation() { firstContent = null; contentAggregation = null; currentHeader = null; + aggregationException = null; + canTripBreaker = true; } private void ensureOpen() { @@ -127,4 +182,65 @@ public class InboundAggregator implements Releasable { throw new IllegalStateException("Aggregator is already closed"); } } + + private void initializeRequestState() { + assert currentHeader.needsToReadVariableHeader() == false; + assert currentHeader.isRequest(); + if (currentHeader.isHandshake()) { + canTripBreaker = false; + return; + } + + final String actionName = currentHeader.getActionName(); + try { + canTripBreaker = requestCanTripBreaker.test(actionName); + } catch (ActionNotFoundTransportException e) { + shortCircuit(e); + } + } + + private void checkBreaker(final Header header, final int contentLength, final BreakerControl breakerControl) { + if (header.isRequest() == false) { + return; + } + assert header.needsToReadVariableHeader() == false; + + if (canTripBreaker) { + try { + circuitBreaker.get().addEstimateBytesAndMaybeBreak(contentLength, header.getActionName()); + breakerControl.setReservedBytes(contentLength); + } catch (CircuitBreakingException e) { + shortCircuit(e); + } + } else { + circuitBreaker.get().addWithoutBreaking(contentLength); + breakerControl.setReservedBytes(contentLength); + } + } + + private static class BreakerControl implements Releasable { + + private static final int CLOSED = -1; + + private final Supplier circuitBreaker; + private final AtomicInteger bytesToRelease = new AtomicInteger(0); + + private BreakerControl(Supplier circuitBreaker) { + this.circuitBreaker = circuitBreaker; + } + + private void setReservedBytes(int reservedBytes) { + final boolean set = bytesToRelease.compareAndSet(0, reservedBytes); + assert set : "Expected bytesToRelease to be 0, found " + bytesToRelease.get(); + } + + @Override + public void close() { + final int toRelease = bytesToRelease.getAndSet(CLOSED); + assert toRelease != CLOSED; + if (toRelease > 0) { + circuitBreaker.get().addWithoutBreaking(-toRelease); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java index 48888a68a67..ef35251ffe8 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java @@ -37,7 +37,6 @@ public class InboundDecoder implements Releasable { private final Version version; private final PageCacheRecycler recycler; - private Exception decodingException; private TransportDecompressor decompressor; private int totalNetworkSize = -1; private int bytesConsumed = 0; @@ -86,13 +85,6 @@ public class InboundDecoder implements Releasable { return headerBytesToRead; } } - } else if (isDecodingFailed()) { - int bytesToConsume = Math.min(reference.length(), totalNetworkSize - bytesConsumed); - bytesConsumed += bytesToConsume; - if (isDone()) { - finishMessage(fragmentConsumer); - } - return bytesToConsume; } else { // There are a minimum number of bytes required to start decompression if (decompressor != null && decompressor.canDecompress(reference.length()) == false) { @@ -130,19 +122,12 @@ public class InboundDecoder implements Releasable { } private void finishMessage(Consumer fragmentConsumer) { - Object finishMarker; - if (decodingException != null) { - finishMarker = decodingException; - } else { - finishMarker = END_CONTENT; - } cleanDecodeState(); - fragmentConsumer.accept(finishMarker); + fragmentConsumer.accept(END_CONTENT); } private void cleanDecodeState() { IOUtils.closeWhileHandlingException(decompressor); - decodingException = null; decompressor = null; totalNetworkSize = -1; bytesConsumed = 0; @@ -190,7 +175,7 @@ public class InboundDecoder implements Releasable { Header header = new Header(networkMessageSize, requestId, status, remoteVersion); final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake()); if (invalidVersion != null) { - decodingException = invalidVersion; + throw invalidVersion; } else { if (remoteVersion.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) { // Skip since we already have ensured enough data available @@ -206,10 +191,6 @@ public class InboundDecoder implements Releasable { return totalNetworkSize == -1; } - private boolean isDecodingFailed() { - return decodingException != null; - } - private void ensureOpen() { if (isClosed) { throw new IllegalStateException("Decoder is already closed"); diff --git a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java index 89002f259a7..d87cbb697f7 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java @@ -23,7 +23,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.Version; -import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -31,14 +30,12 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; import java.net.InetSocketAddress; import java.util.Collections; import java.util.Map; -import java.util.Set; public class InboundHandler { @@ -47,7 +44,6 @@ public class InboundHandler { private final ThreadPool threadPool; private final OutboundHandler outboundHandler; private final NamedWriteableRegistry namedWriteableRegistry; - private final CircuitBreakerService circuitBreakerService; private final TransportHandshaker handshaker; private final TransportKeepAlive keepAlive; @@ -56,11 +52,10 @@ public class InboundHandler { private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; InboundHandler(ThreadPool threadPool, OutboundHandler outboundHandler, NamedWriteableRegistry namedWriteableRegistry, - CircuitBreakerService circuitBreakerService, TransportHandshaker handshaker, TransportKeepAlive keepAlive) { + TransportHandshaker handshaker, TransportKeepAlive keepAlive) { this.threadPool = threadPool; this.outboundHandler = outboundHandler; this.namedWriteableRegistry = namedWriteableRegistry; - this.circuitBreakerService = circuitBreakerService; this.handshaker = handshaker; this.keepAlive = keepAlive; } @@ -73,7 +68,7 @@ public class InboundHandler { } @SuppressWarnings("unchecked") - final RequestHandlerRegistry getRequestHandler(String action) { + public final RequestHandlerRegistry getRequestHandler(String action) { return (RequestHandlerRegistry) requestHandlers.get(action); } @@ -96,26 +91,27 @@ public class InboundHandler { if (message.isPing()) { keepAlive.receiveKeepAlive(channel); } else { - messageReceived(message, channel); + messageReceived(channel, message); } } - private void messageReceived(InboundMessage message, TcpChannel channel) throws IOException { + private void messageReceived(TcpChannel channel, InboundMessage message) throws IOException { final InetSocketAddress remoteAddress = channel.getRemoteAddress(); final Header header = message.getHeader(); assert header.needsToReadVariableHeader() == false; - final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput()); - assertRemoteVersion(streamInput, header.getVersion()); - ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext existing = threadContext.stashContext()) { // Place the context with the headers from the message threadContext.setHeaders(header.getHeaders()); threadContext.putTransient("_remote_address", remoteAddress); if (header.isRequest()) { - handleRequest(channel, header, streamInput, message.getContentLength()); + handleRequest(channel, header, message); } else { + // Responses do not support short circuiting currently + assert message.isShortCircuit() == false; + final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(streamInput, header.getVersion()); final TransportResponseHandler handler; long requestId = header.getRequestId(); if (header.isHandshake()) { @@ -148,55 +144,59 @@ public class InboundHandler { } } - private void handleRequest(TcpChannel channel, Header header, StreamInput stream, int messageLengthBytes) { + private void handleRequest(TcpChannel channel, Header header, InboundMessage message) throws IOException { final String action = header.getActionName(); final long requestId = header.getRequestId(); final Version version = header.getVersion(); - final Set features = header.getFeatures(); - TransportChannel transportChannel = null; - try { + if (header.isHandshake()) { messageListener.onRequestReceived(requestId, action); - if (header.isHandshake()) { - // Handshakes are not currently circuit broken - transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, - circuitBreakerService, 0, header.isCompressed(), header.isHandshake()); - handshaker.handleHandshake(transportChannel, requestId, stream); - } else { - final RequestHandlerRegistry reg = getRequestHandler(action); - if (reg == null) { - throw new ActionNotFoundTransportException(action); - } - CircuitBreaker breaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS); - if (reg.canTripCircuitBreaker()) { - breaker.addEstimateBytesAndMaybeBreak(messageLengthBytes, ""); - } else { - breaker.addWithoutBreaking(messageLengthBytes); - } - transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, - circuitBreakerService, messageLengthBytes, header.isCompressed(), header.isHandshake()); - final T request = reg.newRequest(stream); - request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); - // in case we throw an exception, i.e. when the limit is hit, we don't want to verify - final int nextByte = stream.read(); - // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker - if (nextByte != -1) { - throw new IllegalStateException("Message not fully read (request) for requestId [" + requestId + "], action [" + action - + "], available [" + stream.available() + "]; resetting"); - } - threadPool.executor(reg.getExecutor()).execute(new RequestHandler<>(reg, request, transportChannel)); - } - } catch (Exception e) { - // the circuit breaker tripped - if (transportChannel == null) { - transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, - circuitBreakerService, 0, header.isCompressed(), header.isHandshake()); - } + // Cannot short circuit handshakes + assert message.isShortCircuit() == false; + final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(stream, header.getVersion()); + final TransportChannel transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, + header.getFeatures(), header.isCompressed(), header.isHandshake(), message.takeBreakerReleaseControl()); try { - transportChannel.sendResponse(e); - } catch (IOException inner) { - inner.addSuppressed(e); - logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", action), inner); + handshaker.handleHandshake(transportChannel, requestId, stream); + } catch (Exception e) { + sendErrorResponse(action, transportChannel, e); } + } else { + final TransportChannel transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, + header.getFeatures(), header.isCompressed(), header.isHandshake(), message.takeBreakerReleaseControl()); + try { + messageListener.onRequestReceived(requestId, action); + if (message.isShortCircuit()) { + sendErrorResponse(action, transportChannel, message.getException()); + } else { + final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(stream, header.getVersion()); + final RequestHandlerRegistry reg = getRequestHandler(action); + assert reg != null; + final T request = reg.newRequest(stream); + request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); + // in case we throw an exception, i.e. when the limit is hit, we don't want to verify + final int nextByte = stream.read(); + // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker + if (nextByte != -1) { + throw new IllegalStateException("Message not fully read (request) for requestId [" + requestId + "], action [" + + action + "], available [" + stream.available() + "]; resetting"); + } + threadPool.executor(reg.getExecutor()).execute(new RequestHandler<>(reg, request, transportChannel)); + } + } catch (Exception e) { + sendErrorResponse(action, transportChannel, e); + } + + } + } + + private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) { + try { + transportChannel.sendResponse(e); + } catch (Exception inner) { + inner.addSuppressed(e); + logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", actionName), inner); } } @@ -279,13 +279,7 @@ public class InboundHandler { @Override public void onFailure(Exception e) { - try { - transportChannel.sendResponse(e); - } catch (Exception inner) { - inner.addSuppressed(e); - logger.warn(() -> new ParameterizedMessage( - "Failed to send error message back to client for action [{}]", reg.getAction()), inner); - } + sendErrorResponse(reg.getAction(), transportChannel, e); } } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java index b8f1dfa14a1..99dd23e940d 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java @@ -31,18 +31,32 @@ public class InboundMessage implements Releasable { private final Header header; private final ReleasableBytesReference content; + private final Exception exception; private final boolean isPing; + private Releasable breakerRelease; private StreamInput streamInput; - public InboundMessage(Header header, ReleasableBytesReference content) { + public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { this.header = header; this.content = content; + this.breakerRelease = breakerRelease; + this.exception = null; + this.isPing = false; + } + + public InboundMessage(Header header, Exception exception) { + this.header = header; + this.content = null; + this.breakerRelease = null; + this.exception = exception; this.isPing = false; } public InboundMessage(Header header, boolean isPing) { this.header = header; this.content = null; + this.breakerRelease = null; + this.exception = null; this.isPing = isPing; } @@ -58,10 +72,24 @@ public class InboundMessage implements Releasable { } } + public Exception getException() { + return exception; + } + public boolean isPing() { return isPing; } + public boolean isShortCircuit() { + return exception != null; + } + + public Releasable takeBreakerReleaseControl() { + final Releasable toReturn = breakerRelease; + breakerRelease = null; + return toReturn; + } + public StreamInput openOrGetStreamInput() throws IOException { assert isPing == false && content != null; if (streamInput == null) { @@ -74,6 +102,6 @@ public class InboundMessage implements Releasable { @Override public void close() { IOUtils.closeWhileHandlingException(streamInput); - Releasables.closeWhileHandlingException(content); + Releasables.closeWhileHandlingException(content, breakerRelease); } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java b/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java index 68740b54742..a9e71c55b4f 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java @@ -20,9 +20,9 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.PageCacheRecycler; @@ -31,7 +31,9 @@ import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.function.BiConsumer; +import java.util.function.Function; import java.util.function.LongSupplier; +import java.util.function.Supplier; public class InboundPipeline implements Releasable { @@ -43,26 +45,25 @@ public class InboundPipeline implements Releasable { private final InboundDecoder decoder; private final InboundAggregator aggregator; private final BiConsumer messageHandler; - private final BiConsumer> errorHandler; + private Exception uncaughtException; private ArrayDeque pending = new ArrayDeque<>(2); private boolean isClosed = false; public InboundPipeline(Version version, StatsTracker statsTracker, PageCacheRecycler recycler, LongSupplier relativeTimeInMillis, - BiConsumer messageHandler, - BiConsumer> errorHandler) { - this(statsTracker, relativeTimeInMillis, new InboundDecoder(version, recycler), new InboundAggregator(), messageHandler, - errorHandler); + Supplier circuitBreaker, + Function> registryFunction, + BiConsumer messageHandler) { + this(statsTracker, relativeTimeInMillis, new InboundDecoder(version, recycler), + new InboundAggregator(circuitBreaker, registryFunction), messageHandler); } - private InboundPipeline(StatsTracker statsTracker, LongSupplier relativeTimeInMillis, InboundDecoder decoder, - InboundAggregator aggregator, BiConsumer messageHandler, - BiConsumer> errorHandler) { + public InboundPipeline(StatsTracker statsTracker, LongSupplier relativeTimeInMillis, InboundDecoder decoder, + InboundAggregator aggregator, BiConsumer messageHandler) { this.relativeTimeInMillis = relativeTimeInMillis; this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; this.messageHandler = messageHandler; - this.errorHandler = errorHandler; } @Override @@ -74,6 +75,18 @@ public class InboundPipeline implements Releasable { } public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException { + if (uncaughtException != null) { + throw new IllegalStateException("Pipeline state corrupted by uncaught exception", uncaughtException); + } + try { + doHandleBytes(channel, reference); + } catch (Exception e) { + uncaughtException = e; + throw e; + } + } + + public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException { channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); statsTracker.markBytesRead(reference.length()); pending.add(reference.retain()); @@ -128,15 +141,6 @@ public class InboundPipeline implements Releasable { statsTracker.markMessageReceived(); messageHandler.accept(channel, aggregated); } - } else if (fragment instanceof Exception) { - final Header header; - if (aggregator.isAggregating()) { - header = aggregator.cancelAggregation(); - statsTracker.markMessageReceived(); - } else { - header = null; - } - errorHandler.accept(channel, new Tuple<>(header, (Exception) fragment)); } else { assert aggregator.isAggregating(); assert fragment instanceof ReleasableBytesReference; diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index ea750c2c296..9008b9d9889 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -33,7 +33,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -86,6 +85,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Supplier; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -115,6 +115,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements protected final PageCacheRecycler pageCacheRecycler; protected final NetworkService networkService; protected final Set profileSettings; + private final CircuitBreakerService circuitBreakerService; private final ConcurrentMap profileBoundAddresses = newConcurrentMap(); private final Map> serverChannels = newConcurrentMap(); @@ -138,6 +139,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements this.version = version; this.threadPool = threadPool; this.pageCacheRecycler = pageCacheRecycler; + this.circuitBreakerService = circuitBreakerService; this.networkService = networkService; String nodeName = Node.NODE_NAME_SETTING.get(settings); final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings); @@ -161,8 +163,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), TransportRequestOptions.EMPTY, v, false, true)); this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); - this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker, - keepAlive); + this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive); } public Version getVersion() { @@ -177,6 +178,14 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements return threadPool; } + public Supplier getInflightBreaker() { + return () -> circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS); + } + + public InboundHandler getInboundHandler() { + return inboundHandler; + } + @Override protected void doStart() { } @@ -692,10 +701,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements } } - public void inboundDecodeException(TcpChannel channel, Tuple tuple) { - onException(channel, tuple.v2()); - } - /** * Validates the first 6 bytes of the message header and returns the length of the message. If 6 bytes * are not available, it returns -1. diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java index 817d929e12c..07a0a1b3c12 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java @@ -20,8 +20,7 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; -import org.elasticsearch.common.breaker.CircuitBreaker; -import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.common.lease.Releasable; import java.io.IOException; import java.util.Set; @@ -36,24 +35,21 @@ public final class TcpTransportChannel implements TransportChannel { private final long requestId; private final Version version; private final Set features; - private final CircuitBreakerService breakerService; - private final long reservedBytes; private final boolean compressResponse; private final boolean isHandshake; + private final Releasable breakerRelease; TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version, - Set features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse, - boolean isHandshake) { + Set features, boolean compressResponse, boolean isHandshake, Releasable breakerRelease) { this.version = version; this.features = features; this.channel = channel; this.outboundHandler = outboundHandler; this.action = action; this.requestId = requestId; - this.breakerService = breakerService; - this.reservedBytes = reservedBytes; this.compressResponse = compressResponse; this.isHandshake = isHandshake; + this.breakerRelease = breakerRelease; } @Override @@ -84,7 +80,7 @@ public final class TcpTransportChannel implements TransportChannel { private void release(boolean isExceptionResponse) { if (released.compareAndSet(false, true)) { assert (releaseBy = new Exception()) != null; // easier to debug if it's already closed - breakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS).addWithoutBreaking(-reservedBytes); + breakerRelease.close(); } else if (isExceptionResponse == false) { // only fail if we are not sending an error - we might send the error triggered by the previous // sendResponse call diff --git a/server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java b/server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java index 3b2d6449b8d..82389c29a6c 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java @@ -20,6 +20,8 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.TestCircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.collect.Tuple; @@ -32,20 +34,33 @@ import org.junit.Before; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.function.Predicate; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.notNullValue; public class InboundAggregatorTests extends ESTestCase { private final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + private final String unBreakableAction = "non_breakable_action"; + private final String unknownAction = "unknown_action"; private InboundAggregator aggregator; + private TestCircuitBreaker circuitBreaker; @Before @Override public void setUp() throws Exception { super.setUp(); - aggregator = new InboundAggregator(); + Predicate requestCanTripBreaker = action -> { + if (unknownAction.equals(action)) { + throw new ActionNotFoundTransportException(action); + } else { + return unBreakableAction.equals(action) == false; + } + }; + circuitBreaker = new TestCircuitBreaker(); + aggregator = new InboundAggregator(() -> circuitBreaker, requestCanTripBreaker); } public void testInboundAggregation() throws IOException { @@ -95,7 +110,89 @@ public class InboundAggregatorTests extends ESTestCase { } } - public void testCancelAndCloseWillCloseContent() { + public void testInboundUnknownAction() throws IOException { + long requestId = randomNonNegativeLong(); + Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); + header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); + header.actionName = unknownAction; + // Initiate Message + aggregator.headerReceived(header); + + BytesArray bytes = new BytesArray(randomByteArrayOfLength(10)); + final ReleasableBytesReference content = ReleasableBytesReference.wrap(bytes); + aggregator.aggregate(content); + content.close(); + assertEquals(0, content.refCount()); + + // Signal EOS + InboundMessage aggregated = aggregator.finishAggregation(); + + assertThat(aggregated, notNullValue()); + assertTrue(aggregated.isShortCircuit()); + assertThat(aggregated.getException(), instanceOf(ActionNotFoundTransportException.class)); + } + + public void testCircuitBreak() throws IOException { + circuitBreaker.startBreaking(); + // Actions are breakable + Header breakableHeader = new Header(randomInt(), randomNonNegativeLong(), TransportStatus.setRequest((byte) 0), Version.CURRENT); + breakableHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); + breakableHeader.actionName = "action_name"; + // Initiate Message + aggregator.headerReceived(breakableHeader); + + BytesArray bytes = new BytesArray(randomByteArrayOfLength(10)); + final ReleasableBytesReference content1 = ReleasableBytesReference.wrap(bytes); + aggregator.aggregate(content1); + content1.close(); + + // Signal EOS + InboundMessage aggregated1 = aggregator.finishAggregation(); + + assertEquals(0, content1.refCount()); + assertThat(aggregated1, notNullValue()); + assertTrue(aggregated1.isShortCircuit()); + assertThat(aggregated1.getException(), instanceOf(CircuitBreakingException.class)); + + // Actions marked as unbreakable are not broken + Header unbreakableHeader = new Header(randomInt(), randomNonNegativeLong(), TransportStatus.setRequest((byte) 0), Version.CURRENT); + unbreakableHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); + unbreakableHeader.actionName = unBreakableAction; + // Initiate Message + aggregator.headerReceived(unbreakableHeader); + + final ReleasableBytesReference content2 = ReleasableBytesReference.wrap(bytes); + aggregator.aggregate(content2); + content2.close(); + + // Signal EOS + InboundMessage aggregated2 = aggregator.finishAggregation(); + + assertEquals(1, content2.refCount()); + assertThat(aggregated2, notNullValue()); + assertFalse(aggregated2.isShortCircuit()); + + // Handshakes are not broken + final byte handshakeStatus = TransportStatus.setHandshake(TransportStatus.setRequest((byte) 0)); + Header handshakeHeader = new Header(randomInt(), randomNonNegativeLong(), handshakeStatus, Version.CURRENT); + handshakeHeader.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); + handshakeHeader.actionName = "handshake"; + // Initiate Message + aggregator.headerReceived(handshakeHeader); + + final ReleasableBytesReference content3 = ReleasableBytesReference.wrap(bytes); + aggregator.aggregate(content3); + content3.close(); + + // Signal EOS + InboundMessage aggregated3 = aggregator.finishAggregation(); + + assertEquals(1, content3.refCount()); + assertThat(aggregated3, notNullValue()); + assertFalse(aggregated3.isShortCircuit()); + } + + public void testCloseWillCloseContent() { long requestId = randomNonNegativeLong(); Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); header.headers = new Tuple<>(Collections.emptyMap(), Collections.emptyMap()); @@ -121,11 +218,7 @@ public class InboundAggregatorTests extends ESTestCase { content2.close(); } - if (randomBoolean()) { - aggregator.cancelAggregation(); - } else { - aggregator.close(); - } + aggregator.close(); for (ReleasableBytesReference reference : references) { assertEquals(0, reference.refCount()); @@ -134,6 +227,13 @@ public class InboundAggregatorTests extends ESTestCase { public void testFinishAggregationWillFinishHeader() throws IOException { long requestId = randomNonNegativeLong(); + final String actionName; + final boolean unknownAction = randomBoolean(); + if (unknownAction) { + actionName = this.unknownAction; + } else { + actionName = "action_name"; + } Header header = new Header(randomInt(), requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT); // Initiate Message aggregator.headerReceived(header); @@ -141,18 +241,27 @@ public class InboundAggregatorTests extends ESTestCase { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { threadContext.writeTo(streamOutput); streamOutput.writeStringArray(new String[0]); - streamOutput.writeString("action_name"); + streamOutput.writeString(actionName); streamOutput.write(randomByteArrayOfLength(10)); - aggregator.aggregate(ReleasableBytesReference.wrap(streamOutput.bytes())); + final ReleasableBytesReference content = ReleasableBytesReference.wrap(streamOutput.bytes()); + aggregator.aggregate(content); + content.close(); // Signal EOS InboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertFalse(header.needsToReadVariableHeader()); - assertEquals("action_name", header.getActionName()); + assertEquals(actionName, header.getActionName()); + if (unknownAction) { + assertEquals(0, content.refCount()); + assertTrue(aggregated.isShortCircuit()); + } else { + assertEquals(1, content.refCount()); + assertFalse(aggregated.isShortCircuit()); + } } - } + } diff --git a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java index d8b840f970c..524af40442d 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java @@ -108,20 +108,13 @@ public class InboundDecoderTests extends ESTestCase { public void testDecodePreHeaderSizeVariableInt() throws IOException { // TODO: Can delete test on 9.0 - boolean isRequest = randomBoolean(); boolean isCompressed = randomBoolean(); String action = "test-request"; long requestId = randomNonNegativeLong(); final Version preHeaderVariableInt = Version.V_7_5_0; - OutboundMessage message; final String contentValue = randomAlphaOfLength(100); - if (isRequest) { - message = new OutboundMessage.Request(threadContext, new String[0], new TestRequest(contentValue), - preHeaderVariableInt, action, requestId, false, isCompressed); - } else { - message = new OutboundMessage.Response(threadContext, Collections.emptySet(), new TestResponse(contentValue), - preHeaderVariableInt, requestId, false, isCompressed); - } + final OutboundMessage message = new OutboundMessage.Request(threadContext, new String[0], new TestRequest(contentValue), + preHeaderVariableInt, action, requestId, true, isCompressed); final BytesReference totalBytes = message.serialize(new BytesStreamOutput()); int partialHeaderSize = TcpHeader.headerSize(preHeaderVariableInt); @@ -137,12 +130,8 @@ public class InboundDecoderTests extends ESTestCase { assertEquals(requestId, header.getRequestId()); assertEquals(preHeaderVariableInt, header.getVersion()); assertEquals(isCompressed, header.isCompressed()); - assertFalse(header.isHandshake()); - if (isRequest) { - assertTrue(header.isRequest()); - } else { - assertTrue(header.isResponse()); - } + assertTrue(header.isHandshake()); + assertTrue(header.isRequest()); assertTrue(header.needsToReadVariableHeader()); fragments.clear(); @@ -290,25 +279,13 @@ public class InboundDecoderTests extends ESTestCase { incompatibleVersion, action, requestId, false, true); final BytesReference bytes = message.serialize(new BytesStreamOutput()); - int totalHeaderSize = TcpHeader.headerSize(incompatibleVersion); InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final ArrayList fragments = new ArrayList<>(); final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes); - int bytesConsumed = decoder.decode(releasable1, fragments::add); - assertEquals(totalHeaderSize, bytesConsumed); + expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, fragments::add)); + // No bytes are retained assertEquals(1, releasable1.refCount()); - - final Header header = (Header) fragments.get(0); - assertEquals(requestId, header.getRequestId()); - assertEquals(incompatibleVersion, header.getVersion()); - fragments.clear(); - - final int remaining = bytes.length() - bytesConsumed; - final BytesReference bytes2 = bytes.slice(bytesConsumed, remaining); - final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2); - bytesConsumed = decoder.decode(releasable2, fragments::add); - assertEquals(remaining, bytesConsumed); } public void testEnsureVersionCompatibility() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java index 21bb8a1a425..97f1fb1183c 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java @@ -28,7 +28,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; @@ -62,8 +61,7 @@ public class InboundHandlerTests extends ESTestCase { TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage); OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool, BigArrays.NON_RECYCLING_INSTANCE); - final NoneCircuitBreakerService breaker = new NoneCircuitBreakerService(); - handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, breaker, handshaker, keepAlive); + handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive); } @After @@ -129,7 +127,7 @@ public class InboundHandlerTests extends ESTestCase { BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent)); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -150,7 +148,7 @@ public class InboundHandlerTests extends ESTestCase { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); Header responseHeader = new Header(fullRequestBytes.length() - 6, requestId, responseStatus, version); - InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent)); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); diff --git a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java index c67ff3251e9..e44340ab2b5 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java @@ -20,6 +20,11 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.breaker.TestCircuitBreaker; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.collect.Tuple; @@ -40,6 +45,8 @@ import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.LongSupplier; +import java.util.function.Predicate; +import java.util.function.Supplier; import static org.hamcrest.Matchers.instanceOf; @@ -60,34 +67,31 @@ public class InboundPipelineTests extends ESTestCase { final boolean isRequest = header.isRequest(); final long requestId = header.getRequestId(); final boolean isCompressed = header.isCompressed(); - if (isRequest) { + if (m.isShortCircuit()) { + actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), null); + } else if (isRequest) { final TestRequest request = new TestRequest(m.openOrGetStreamInput()); actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), request.value); } else { final TestResponse response = new TestResponse(m.openOrGetStreamInput()); actualData = new MessageData(version, requestId, isRequest, isCompressed, null, response.value); } - actual.add(new Tuple<>(actualData, null)); + actual.add(new Tuple<>(actualData, m.getException())); } catch (IOException e) { throw new AssertionError(e); } }; - final BiConsumer> errorHandler = (c, tuple) -> { - final Header header = tuple.v1(); - final MessageData actualData; - final Version version = header.getVersion(); - final boolean isRequest = header.isRequest(); - final long requestId = header.getRequestId(); - final boolean isCompressed = header.isCompressed(); - actualData = new MessageData(version, requestId, isRequest, isCompressed, null, null); - actual.add(new Tuple<>(actualData, tuple.v2())); - }; - final PageCacheRecycler recycler = PageCacheRecycler.NON_RECYCLING_INSTANCE; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); - final InboundPipeline pipeline = new InboundPipeline(Version.CURRENT, statsTracker, recycler, millisSupplier, messageHandler, - errorHandler); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final String breakThisAction = "break_this_action"; + final String actionName = "actionName"; + final Predicate canTripBreaker = breakThisAction::equals; + final TestCircuitBreaker circuitBreaker = new TestCircuitBreaker(); + circuitBreaker.startBreaking(); + final InboundAggregator aggregator = new InboundAggregator(() -> circuitBreaker, canTripBreaker); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); final FakeTcpChannel channel = new FakeTcpChannel(); final int iterations = randomIntBetween(100, 500); @@ -100,15 +104,7 @@ public class InboundPipelineTests extends ESTestCase { toRelease.clear(); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { while (streamOutput.size() < BYTE_THRESHOLD) { - final boolean invalidVersion = rarely(); - - String actionName = "actionName"; - final Version version; - if (invalidVersion) { - version = Version.CURRENT.minimumCompatibilityVersion().minimumCompatibilityVersion(); - } else { - version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); - } + final Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); final String value = randomAlphaOfLength(randomIntBetween(10, 200)); final boolean isRequest = randomBoolean(); final boolean isCompressed = randomBoolean(); @@ -119,21 +115,18 @@ public class InboundPipelineTests extends ESTestCase { OutboundMessage message; if (isRequest) { - if (invalidVersion) { - expectedExceptionClass = new IllegalStateException(); - messageData = new MessageData(version, requestId, true, isCompressed, null, null); + if (rarely()) { + messageData = new MessageData(version, requestId, true, isCompressed, breakThisAction, null); + message = new OutboundMessage.Request(threadContext, new String[0], new TestRequest(value), + version, breakThisAction, requestId, false, isCompressed); + expectedExceptionClass = new CircuitBreakingException("", CircuitBreaker.Durability.PERMANENT); } else { messageData = new MessageData(version, requestId, true, isCompressed, actionName, value); + message = new OutboundMessage.Request(threadContext, new String[0], new TestRequest(value), + version, actionName, requestId, false, isCompressed); } - message = new OutboundMessage.Request(threadContext, new String[0], new TestRequest(value), - version, actionName, requestId, false, isCompressed); } else { - if (invalidVersion) { - expectedExceptionClass = new IllegalStateException(); - messageData = new MessageData(version, requestId, false, isCompressed, null, null); - } else { - messageData = new MessageData(version, requestId, false, isCompressed, null, value); - } + messageData = new MessageData(version, requestId, false, isCompressed, null, value); message = new OutboundMessage.Response(threadContext, Collections.emptySet(), new TestResponse(value), version, requestId, false, isCompressed); } @@ -166,8 +159,8 @@ public class InboundPipelineTests extends ESTestCase { assertEquals(expectedMessageData.requestId, actualMessageData.requestId); assertEquals(expectedMessageData.isRequest, actualMessageData.isRequest); assertEquals(expectedMessageData.isCompressed, actualMessageData.isCompressed); - assertEquals(expectedMessageData.value, actualMessageData.value); assertEquals(expectedMessageData.actionName, actualMessageData.actionName); + assertEquals(expectedMessageData.value, actualMessageData.value); if (expectedTuple.v2() != null) { assertNotNull(actualTuple.v2()); assertThat(actualTuple.v2(), instanceOf(expectedTuple.v2().getClass())); @@ -184,14 +177,51 @@ public class InboundPipelineTests extends ESTestCase { } } - public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { - final PageCacheRecycler recycler = PageCacheRecycler.NON_RECYCLING_INSTANCE; + public void testDecodeExceptionIsPropagated() throws IOException { BiConsumer messageHandler = (c, m) -> {}; - BiConsumer> errorHandler = (c, e) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); - final InboundPipeline pipeline = new InboundPipeline(Version.CURRENT, statsTracker, recycler, millisSupplier, messageHandler, - errorHandler); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + + try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { + String actionName = "actionName"; + final Version invalidVersion = Version.CURRENT.minimumCompatibilityVersion().minimumCompatibilityVersion(); + final String value = randomAlphaOfLength(1000); + final boolean isRequest = randomBoolean(); + final long requestId = randomNonNegativeLong(); + + OutboundMessage message; + if (isRequest) { + message = new OutboundMessage.Request(threadContext, new String[0], new TestRequest(value), + invalidVersion, actionName, requestId, false, false); + } else { + message = new OutboundMessage.Response(threadContext, Collections.emptySet(), new TestResponse(value), + invalidVersion, requestId, false, false); + } + + final BytesReference reference = message.serialize(streamOutput); + try (ReleasableBytesReference releasable = ReleasableBytesReference.wrap(reference)) { + expectThrows(IllegalStateException.class, () -> pipeline.handleBytes(new FakeTcpChannel(), releasable)); + } + + // Pipeline cannot be reused after uncaught exception + final IllegalStateException ise = expectThrows(IllegalStateException.class, + () -> pipeline.handleBytes(new FakeTcpChannel(), ReleasableBytesReference.wrap(BytesArray.EMPTY))); + assertEquals("Pipeline state corrupted by uncaught exception", ise.getMessage()); + } + } + + public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { + BiConsumer messageHandler = (c, m) -> {}; + final StatsTracker statsTracker = new StatsTracker(); + final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { String actionName = "actionName"; diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java index ea5c639e8a7..e0533724519 100644 --- a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -23,6 +23,8 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -48,6 +50,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.LongSupplier; +import java.util.function.Predicate; +import java.util.function.Supplier; import static org.hamcrest.Matchers.instanceOf; @@ -59,7 +63,6 @@ public class OutboundHandlerTests extends ESTestCase { private final TransportRequestOptions options = TransportRequestOptions.EMPTY; private final AtomicReference> message = new AtomicReference<>(); private InboundPipeline pipeline; - private StatsTracker statsTracker; private OutboundHandler handler; private FakeTcpChannel channel; private DiscoveryNode node; @@ -71,11 +74,14 @@ public class OutboundHandlerTests extends ESTestCase { TransportAddress transportAddress = buildNewFakeTransportAddress(); node = new DiscoveryNode("", transportAddress, Version.CURRENT); String[] features = {feature1, feature2}; - statsTracker = new StatsTracker(); + StatsTracker statsTracker = new StatsTracker(); handler = new OutboundHandler("node", Version.CURRENT, features, statsTracker, threadPool, BigArrays.NON_RECYCLING_INSTANCE); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); - pipeline = new InboundPipeline(Version.CURRENT, new StatsTracker(), PageCacheRecycler.NON_RECYCLING_INSTANCE, millisSupplier, + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); + pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { Streams.copy(m.openOrGetStreamInput(), streamOutput); @@ -83,9 +89,7 @@ public class OutboundHandlerTests extends ESTestCase { } catch (IOException e) { throw new AssertionError(e); } - }, (c, t) -> { - throw new AssertionError(t.v2()); - }); + }); } @After diff --git a/test/framework/src/main/java/org/elasticsearch/common/breaker/TestCircuitBreaker.java b/test/framework/src/main/java/org/elasticsearch/common/breaker/TestCircuitBreaker.java new file mode 100644 index 00000000000..e2deffc52e7 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/common/breaker/TestCircuitBreaker.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.common.breaker; + +import java.util.concurrent.atomic.AtomicBoolean; + +public class TestCircuitBreaker extends NoopCircuitBreaker { + + private final AtomicBoolean shouldBreak = new AtomicBoolean(false); + + public TestCircuitBreaker() { + super("test"); + } + + @Override + public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + if (shouldBreak.get()) { + throw new CircuitBreakingException("broken", getDurability()); + } + return 0; + } + + public void startBreaking() { + shouldBreak.set(true); + } + + public void stopBreaking() { + shouldBreak.set(false); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 251f3a5e6ff..def7ed4bf0c 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -1642,9 +1642,10 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { latch.await(); assertFalse(requestProcessed.get()); + } - service.acceptIncomingRequests(); - + service.acceptIncomingRequests(); + try (Transport.Connection connection = serviceA.openConnection(node, null)) { CountDownLatch latch2 = new CountDownLatch(1); serviceA.sendRequest(connection, "internal:action", new TestRequest(), TransportRequestOptions.EMPTY, new TransportResponseHandler() { @@ -2026,25 +2027,6 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { public void testTcpHandshake() { assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport); - try (MockTransportService service = buildService("TS_BAD", Version.CURRENT, Settings.EMPTY)) { - service.addMessageListener(new TransportMessageListener() { - @Override - public void onRequestReceived(long requestId, String action) { - if (TransportHandshaker.HANDSHAKE_ACTION_NAME.equals(action)) { - throw new ActionNotFoundTransportException(action); - } - } - }); - service.start(); - service.acceptIncomingRequests(); - // this acts like a node that doesn't have support for handshakes - DiscoveryNode node = - new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); - ConnectTransportException exception = expectThrows(ConnectTransportException.class, () -> serviceA.connectToNode(node)); - assertThat(exception.getCause(), instanceOf(IllegalStateException.class)); - assertEquals("handshake failed", exception.getCause().getMessage()); - } - ConnectionProfile connectionProfile = ConnectionProfile.buildDefaultConnectionProfile(Settings.EMPTY); try (TransportService service = buildService("TS_TPC", Version.CURRENT, Settings.EMPTY)) { DiscoveryNode node = new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index e68b265a672..19e561440b7 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -27,6 +27,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -53,7 +54,9 @@ import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectionProfile; +import org.elasticsearch.transport.InboundHandler; import org.elasticsearch.transport.InboundPipeline; +import org.elasticsearch.transport.StatsTracker; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpServerChannel; import org.elasticsearch.transport.TcpTransport; @@ -274,8 +277,12 @@ public class MockNioTransport extends TcpTransport { private MockTcpReadWriteHandler(MockSocketChannel channel, PageCacheRecycler recycler, TcpTransport transport) { this.channel = channel; final ThreadPool threadPool = transport.getThreadPool(); - this.pipeline = new InboundPipeline(transport.getVersion(), transport.getStatsTracker(), recycler, - threadPool::relativeTimeInMillis, transport::inboundMessage, transport::inboundDecodeException); + final Supplier breaker = transport.getInflightBreaker(); + final InboundHandler inboundHandler = transport.getInboundHandler(); + final Version version = transport.getVersion(); + final StatsTracker statsTracker = transport.getStatsTracker(); + this.pipeline = new InboundPipeline(version, statsTracker, recycler, threadPool::relativeTimeInMillis, breaker, + inboundHandler::getRequestHandler, transport::inboundMessage); } @Override